blob: 77713377c6be8b095a1348232dec46cab68f8ea1 [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<title>pyspark.ml.tuning &#8212; PySpark 3.5.3 documentation</title>
<link href="../../../_static/styles/theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link href="../../../_static/styles/pydata-sphinx-theme.css?digest=1999514e3f237ded88cf" rel="stylesheet">
<link rel="stylesheet"
href="../../../_static/vendor/fontawesome/5.13.0/css/all.min.css">
<link rel="preload" as="font" type="font/woff2" crossorigin
href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-solid-900.woff2">
<link rel="preload" as="font" type="font/woff2" crossorigin
href="../../../_static/vendor/fontawesome/5.13.0/webfonts/fa-brands-400.woff2">
<link rel="stylesheet" href="../../../_static/styles/pydata-sphinx-theme.css" type="text/css" />
<link rel="stylesheet" href="../../../_static/pygments.css" type="text/css" />
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/pyspark.css" />
<link rel="preload" as="script" href="../../../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf">
<script id="documentation_options" data-url_root="../../../" src="../../../_static/documentation_options.js"></script>
<script src="../../../_static/jquery.js"></script>
<script src="../../../_static/underscore.js"></script>
<script src="../../../_static/doctools.js"></script>
<script src="../../../_static/language_data.js"></script>
<script src="../../../_static/clipboard.min.js"></script>
<script src="../../../_static/copybutton.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script async="async" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
<script type="text/x-mathjax-config">MathJax.Hub.Config({"tex2jax": {"inlineMath": [["$", "$"], ["\\(", "\\)"]], "processEscapes": true, "ignoreClass": "document", "processClass": "math|output_area"}})</script>
<link rel="canonical" href="https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/tuning.html" />
<link rel="search" title="Search" href="../../../search.html" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="docsearch:language" content="en">
<!-- Google Analytics -->
</head>
<body data-spy="scroll" data-target="#bd-toc-nav" data-offset="80">
<div class="container-fluid" id="banner"></div>
<nav class="navbar navbar-light navbar-expand-lg bg-light fixed-top bd-navbar" id="navbar-main"><div class="container-xl">
<div id="navbar-start">
<a class="navbar-brand" href="../../../index.html">
<img src="../../../_static/spark-logo-reverse.png" class="logo" alt="logo">
</a>
</div>
<button class="navbar-toggler" type="button" data-toggle="collapse" data-target="#navbar-collapsible" aria-controls="navbar-collapsible" aria-expanded="false" aria-label="Toggle navigation">
<span class="navbar-toggler-icon"></span>
</button>
<div id="navbar-collapsible" class="col-lg-9 collapse navbar-collapse">
<div id="navbar-center" class="mr-auto">
<div class="navbar-center-item">
<ul id="navbar-main-elements" class="navbar-nav">
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../index.html">
Overview
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../getting_started/index.html">
Getting Started
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../user_guide/index.html">
User Guides
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../reference/index.html">
API Reference
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../development/index.html">
Development
</a>
</li>
<li class="toctree-l1 nav-item">
<a class="reference internal nav-link" href="../../../migration_guide/index.html">
Migration Guides
</a>
</li>
</ul>
</div>
</div>
<div id="navbar-end">
<div class="navbar-end-item">
<!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<div id="version-button" class="dropdown">
<button type="button" class="btn btn-secondary btn-sm navbar-btn dropdown-toggle" id="version_switcher_button" data-toggle="dropdown">
3.5.3
<span class="caret"></span>
</button>
<div id="version_switcher" class="dropdown-menu list-group-flush py-0" aria-labelledby="version_switcher_button">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div>
<script type="text/javascript">
// Build the docs-homepage URL for one entry of versions.json by putting the
// entry's `version` field into the URL template.
function buildURL(entry) {
  const urlTemplate = "https://spark.apache.org/docs/{version}/api/python/index.html"; // supplied by jinja
  return urlTemplate.replace("{version}", entry.version);
}
// Click handler for a version-switcher link: try to open the page that
// corresponds to the current one in the selected docs version, and if that
// page cannot be fetched, go to the selected version's homepage instead.
// Always returns false so the link's default navigation is cancelled; the
// HEAD request's callbacks decide where to go.
function checkPageExistsAndRedirect(event) {
  const currentFilePath = "_modules/pyspark/ml/tuning.html";
  // currentTarget is the <a> the handler is attached to, even if the click
  // landed on a child node; target is kept as a fallback.
  const link = event.currentTarget || event.target;
  const otherDocsHomepage = link.getAttribute("href");
  // The href points at the other version's "index.html" (see buildURL), so
  // drop the trailing file name before appending the page path. Plain
  // concatenation would yield ".../index.html_modules/..." and always 404.
  const otherDocsRoot = otherDocsHomepage.replace(/[^/]*$/, "");
  const tryUrl = `${otherDocsRoot}${currentFilePath}`;
  $.ajax({
    type: 'HEAD',
    url: tryUrl,
    // if the page exists, go there
    success: function() {
      location.href = tryUrl;
    }
  }).fail(function() {
    // otherwise fall back to that version's homepage
    location.href = otherDocsHomepage;
  });
  return false;
}
// Populate the version switcher: fetch the published versions.json and
// append one link per entry to #version_switcher, in the order the JSON lists
// them. Each link initially points at that version's docs homepage; clicking
// it goes through checkPageExistsAndRedirect. The nodes are all created in
// this single callback so their order does not depend on later AJAX calls.
(function () {
  const versionsUrl = "https://spark.apache.org/static/versions.json";
  $.getJSON(versionsUrl, function (versions) {
    const $switcher = $("#version_switcher");
    $.each(versions, function (i, item) {
      // entries without a custom display name (e.g., "latest") show their version string
      if (!("name" in item)) {
        item.name = item.version;
      }
      item.url = buildURL(item);
      const link = document.createElement("a");
      link.className = "list-group-item list-group-item-action py-1";
      link.setAttribute("href", `${item.url}`);
      link.textContent = `${item.name}`;
      link.onclick = checkPageExistsAndRedirect;
      $switcher.append(link);
    });
  });
})();
</script>
</div>
</div>
</div>
</div>
</nav>
<div class="container-xl">
<div class="row">
<!-- Only show if we have sidebars configured, else just a small margin -->
<div class="col-12 col-md-3 bd-sidebar">
<div class="sidebar-start-items"><form class="bd-search d-flex align-items-center" action="../../../search.html" method="get">
<i class="icon fas fa-search"></i>
<input type="search" class="form-control" name="q" id="search-input" placeholder="Search the docs ..." aria-label="Search the docs ..." autocomplete="off" >
</form><nav class="bd-links" id="bd-docs-nav" aria-label="Main navigation">
<div class="bd-toc-item active">
</div>
</nav>
</div>
<div class="sidebar-end-items">
</div>
</div>
<div class="d-none d-xl-block col-xl-2 bd-toc">
</div>
<main class="col-12 col-md-9 col-xl-7 py-md-5 pl-md-5 pr-md-4 bd-content" role="main">
<div>
<h1>Source code for pyspark.ml.tuning</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Licensed to the Apache Software Foundation (ASF) under one or more</span>
<span class="c1"># contributor license agreements. See the NOTICE file distributed with</span>
<span class="c1"># this work for additional information regarding copyright ownership.</span>
<span class="c1"># The ASF licenses this file to You under the Apache License, Version 2.0</span>
<span class="c1"># (the &quot;License&quot;); you may not use this file except in compliance with</span>
<span class="c1"># the License. You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">sys</span>
<span class="kn">import</span> <span class="nn">itertools</span>
<span class="kn">from</span> <span class="nn">multiprocessing.pool</span> <span class="kn">import</span> <span class="n">ThreadPool</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span>
<span class="n">Any</span><span class="p">,</span>
<span class="n">Callable</span><span class="p">,</span>
<span class="n">Dict</span><span class="p">,</span>
<span class="n">Iterable</span><span class="p">,</span>
<span class="n">List</span><span class="p">,</span>
<span class="n">Optional</span><span class="p">,</span>
<span class="n">Sequence</span><span class="p">,</span>
<span class="n">Tuple</span><span class="p">,</span>
<span class="n">Type</span><span class="p">,</span>
<span class="n">Union</span><span class="p">,</span>
<span class="n">cast</span><span class="p">,</span>
<span class="n">overload</span><span class="p">,</span>
<span class="n">TYPE_CHECKING</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">pyspark</span> <span class="kn">import</span> <span class="n">keyword_only</span><span class="p">,</span> <span class="n">since</span><span class="p">,</span> <span class="n">SparkContext</span><span class="p">,</span> <span class="n">inheritable_thread_target</span>
<span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">,</span> <span class="n">Model</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.common</span> <span class="kn">import</span> <span class="n">inherit_doc</span><span class="p">,</span> <span class="n">_py2java</span><span class="p">,</span> <span class="n">_java2py</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">Evaluator</span><span class="p">,</span> <span class="n">JavaEvaluator</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.param</span> <span class="kn">import</span> <span class="n">Params</span><span class="p">,</span> <span class="n">Param</span><span class="p">,</span> <span class="n">TypeConverters</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.param.shared</span> <span class="kn">import</span> <span class="n">HasCollectSubModels</span><span class="p">,</span> <span class="n">HasParallelism</span><span class="p">,</span> <span class="n">HasSeed</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.util</span> <span class="kn">import</span> <span class="p">(</span>
<span class="n">DefaultParamsReader</span><span class="p">,</span>
<span class="n">DefaultParamsWriter</span><span class="p">,</span>
<span class="n">MetaAlgorithmReadWrite</span><span class="p">,</span>
<span class="n">MLReadable</span><span class="p">,</span>
<span class="n">MLReader</span><span class="p">,</span>
<span class="n">MLWritable</span><span class="p">,</span>
<span class="n">MLWriter</span><span class="p">,</span>
<span class="n">JavaMLReader</span><span class="p">,</span>
<span class="n">JavaMLWriter</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.wrapper</span> <span class="kn">import</span> <span class="n">JavaParams</span><span class="p">,</span> <span class="n">JavaEstimator</span><span class="p">,</span> <span class="n">JavaWrapper</span>
<span class="kn">from</span> <span class="nn">pyspark.sql.functions</span> <span class="kn">import</span> <span class="n">col</span><span class="p">,</span> <span class="n">lit</span><span class="p">,</span> <span class="n">rand</span><span class="p">,</span> <span class="n">UserDefinedFunction</span>
<span class="kn">from</span> <span class="nn">pyspark.sql.types</span> <span class="kn">import</span> <span class="n">BooleanType</span>
<span class="kn">from</span> <span class="nn">pyspark.sql.dataframe</span> <span class="kn">import</span> <span class="n">DataFrame</span>
<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">pyspark.ml._typing</span> <span class="kn">import</span> <span class="n">ParamMap</span>
<span class="kn">from</span> <span class="nn">py4j.java_gateway</span> <span class="kn">import</span> <span class="n">JavaObject</span>
<span class="kn">from</span> <span class="nn">py4j.java_collections</span> <span class="kn">import</span> <span class="n">JavaArray</span>
<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span>
<span class="s2">&quot;ParamGridBuilder&quot;</span><span class="p">,</span>
<span class="s2">&quot;CrossValidator&quot;</span><span class="p">,</span>
<span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">,</span>
<span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">def</span> <span class="nf">_parallelFitTasks</span><span class="p">(</span>
<span class="n">est</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">,</span>
<span class="n">train</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">,</span>
<span class="n">eva</span><span class="p">:</span> <span class="n">Evaluator</span><span class="p">,</span>
<span class="n">validation</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">,</span>
<span class="n">epm</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">],</span>
<span class="n">collectSubModel</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Callable</span><span class="p">[[],</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">]]]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Creates a list of callables which can be called from different threads to fit and evaluate</span>
<span class="sd"> an estimator in parallel. Each callable returns an `(index, metric, subModel)` tuple.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> est : :py:class:`pyspark.ml.base.Estimator`</span>
<span class="sd"> The estimator to be fit.</span>
<span class="sd"> train : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> DataFrame, training data set, used for fitting.</span>
<span class="sd"> eva : :py:class:`pyspark.ml.evaluation.Evaluator`</span>
<span class="sd"> used to compute `metric`</span>
<span class="sd"> validation : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> DataFrame, validation data set, used for evaluation.</span>
<span class="sd"> epm : :py:class:`collections.abc.Sequence`</span>
<span class="sd"> Sequence of ParamMap, params maps to be used during fitting &amp; evaluation.</span>
<span class="sd"> collectSubModel : bool</span>
<span class="sd"> Whether to collect sub model.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> list</span>
<span class="sd"> Callables, each returning (int, float, subModel): an index into `epm`, the associated metric value, and the fitted model (None unless `collectSubModel` is True).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">modelIter</span> <span class="o">=</span> <span class="n">est</span><span class="o">.</span><span class="n">fitMultiple</span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">epm</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">singleTask</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">]:</span>
<span class="n">index</span><span class="p">,</span> <span class="n">model</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">modelIter</span><span class="p">)</span>
<span class="c1"># TODO: duplicate evaluator to take extra params from input</span>
<span class="c1"># Note: Supporting tuning params in evaluator need update method</span>
<span class="c1"># `MetaAlgorithmReadWrite.getAllNestedStages`, make it return</span>
<span class="c1"># all nested stages and evaluators</span>
<span class="n">metric</span> <span class="o">=</span> <span class="n">eva</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">validation</span><span class="p">,</span> <span class="n">epm</span><span class="p">[</span><span class="n">index</span><span class="p">]))</span>
<span class="k">return</span> <span class="n">index</span><span class="p">,</span> <span class="n">metric</span><span class="p">,</span> <span class="n">model</span> <span class="k">if</span> <span class="n">collectSubModel</span> <span class="k">else</span> <span class="kc">None</span>
<span class="k">return</span> <span class="p">[</span><span class="n">singleTask</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span>
<div class="viewcode-block" id="ParamGridBuilder"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder">[docs]</a><span class="k">class</span> <span class="nc">ParamGridBuilder</span><span class="p">:</span>
<span class="w"> </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Builder for a param grid used in grid search-based model selection.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.classification import LogisticRegression</span>
<span class="sd"> &gt;&gt;&gt; lr = LogisticRegression()</span>
<span class="sd"> &gt;&gt;&gt; output = ParamGridBuilder() \</span>
<span class="sd"> ... .baseOn({lr.labelCol: &#39;l&#39;}) \</span>
<span class="sd"> ... .baseOn([lr.predictionCol, &#39;p&#39;]) \</span>
<span class="sd"> ... .addGrid(lr.regParam, [1.0, 2.0]) \</span>
<span class="sd"> ... .addGrid(lr.maxIter, [1, 5]) \</span>
<span class="sd"> ... .build()</span>
<span class="sd"> &gt;&gt;&gt; expected = [</span>
<span class="sd"> ... {lr.regParam: 1.0, lr.maxIter: 1, lr.labelCol: &#39;l&#39;, lr.predictionCol: &#39;p&#39;},</span>
<span class="sd"> ... {lr.regParam: 2.0, lr.maxIter: 1, lr.labelCol: &#39;l&#39;, lr.predictionCol: &#39;p&#39;},</span>
<span class="sd"> ... {lr.regParam: 1.0, lr.maxIter: 5, lr.labelCol: &#39;l&#39;, lr.predictionCol: &#39;p&#39;},</span>
<span class="sd"> ... {lr.regParam: 2.0, lr.maxIter: 5, lr.labelCol: &#39;l&#39;, lr.predictionCol: &#39;p&#39;}]</span>
<span class="sd"> &gt;&gt;&gt; len(output) == len(expected)</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; all([m in expected for m in output])</span>
<span class="sd"> True</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="p">:</span> <span class="s2">&quot;ParamMap&quot;</span> <span class="o">=</span> <span class="p">{}</span>
<div class="viewcode-block" id="ParamGridBuilder.addGrid"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder.addGrid">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">addGrid</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">param</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Any</span><span class="p">],</span> <span class="n">values</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;ParamGridBuilder&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Adds a param with multiple candidate values to this grid (overwrites if present).</span>
<span class="sd"> param must be an instance of Param associated with an instance of Params</span>
<span class="sd"> (such as Estimator or Transformer).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">Param</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">values</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;param must be an instance of Param&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span></div>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">baseOn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">__args</span><span class="p">:</span> <span class="s2">&quot;ParamMap&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;ParamGridBuilder&quot;</span><span class="p">:</span>
<span class="o">...</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">baseOn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Param</span><span class="p">,</span> <span class="n">Any</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;ParamGridBuilder&quot;</span><span class="p">:</span>
<span class="o">...</span>
<div class="viewcode-block" id="ParamGridBuilder.baseOn"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder.baseOn">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">baseOn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Param</span><span class="p">,</span> <span class="n">Any</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="s2">&quot;ParamGridBuilder&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the given parameters in this grid to fixed values.</span>
<span class="sd"> Accepts either a parameter dictionary or a list of (parameter, value) pairs.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="nb">dict</span><span class="p">):</span>
<span class="bp">self</span><span class="o">.</span><span class="n">baseOn</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">items</span><span class="p">())</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">for</span> <span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> <span class="ow">in</span> <span class="n">args</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">addGrid</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="p">[</span><span class="n">value</span><span class="p">])</span>
<span class="k">return</span> <span class="bp">self</span></div>
<div class="viewcode-block" id="ParamGridBuilder.build"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.ParamGridBuilder.html#pyspark.ml.tuning.ParamGridBuilder.build">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Builds and returns all combinations of parameters specified</span>
<span class="sd"> by the param grid.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">keys</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="o">.</span><span class="n">keys</span><span class="p">()</span>
<span class="n">grid_values</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_param_grid</span><span class="o">.</span><span class="n">values</span><span class="p">()</span>
<span class="k">def</span> <span class="nf">to_key_value_pairs</span><span class="p">(</span>
<span class="n">keys</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Param</span><span class="p">],</span> <span class="n">values</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="n">Any</span><span class="p">]</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Sequence</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">Param</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]:</span>
<span class="k">return</span> <span class="p">[(</span><span class="n">key</span><span class="p">,</span> <span class="n">key</span><span class="o">.</span><span class="n">typeConverter</span><span class="p">(</span><span class="n">value</span><span class="p">))</span> <span class="k">for</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">values</span><span class="p">)]</span>
<span class="k">return</span> <span class="p">[</span><span class="nb">dict</span><span class="p">(</span><span class="n">to_key_value_pairs</span><span class="p">(</span><span class="n">keys</span><span class="p">,</span> <span class="n">prod</span><span class="p">))</span> <span class="k">for</span> <span class="n">prod</span> <span class="ow">in</span> <span class="n">itertools</span><span class="o">.</span><span class="n">product</span><span class="p">(</span><span class="o">*</span><span class="n">grid_values</span><span class="p">)]</span></div></div>
<span class="k">class</span> <span class="nc">_ValidatorParams</span><span class="p">(</span><span class="n">HasSeed</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Common params for TrainValidationSplit and CrossValidator.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">estimator</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> <span class="s2">&quot;estimator&quot;</span><span class="p">,</span> <span class="s2">&quot;estimator to be cross-validated&quot;</span>
<span class="p">)</span>
<span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> <span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">,</span> <span class="s2">&quot;estimator param maps&quot;</span>
<span class="p">)</span>
<span class="n">evaluator</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;evaluator&quot;</span><span class="p">,</span>
<span class="s2">&quot;evaluator used to select hyper-parameters that maximize the validator metric&quot;</span><span class="p">,</span>
<span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getEstimator</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Estimator</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of estimator or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getEstimatorParamMaps</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of estimatorParamMaps or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Evaluator</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of evaluator or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">)</span>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_from_java_impl</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Estimator</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">],</span> <span class="n">Evaluator</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Load information from java_stage to the instance.</span>
<span class="n">estimator</span><span class="p">:</span> <span class="n">Estimator</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span>
<span class="n">evaluator</span><span class="p">:</span> <span class="n">Evaluator</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">())</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">estimator</span><span class="p">,</span> <span class="n">JavaEstimator</span><span class="p">):</span>
<span class="n">epms</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">estimator</span><span class="o">.</span><span class="n">_transfer_param_map_from_java</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span>
<span class="k">for</span> <span class="n">epm</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()</span>
<span class="p">]</span>
<span class="k">elif</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">isMetaEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">):</span>
<span class="c1"># Meta estimator such as Pipeline, OneVsRest</span>
<span class="n">epms</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">meta_estimator_transfer_param_maps_from_java</span><span class="p">(</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported estimator used in tuning: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">estimator</span><span class="p">))</span>
<span class="k">return</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span>
<span class="k">def</span> <span class="nf">_to_java_impl</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="s2">&quot;JavaObject&quot;</span><span class="p">,</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">,</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">gateway</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_gateway</span>
<span class="k">assert</span> <span class="n">gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="bp">cls</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">param</span><span class="o">.</span><span class="n">ParamMap</span>
<span class="n">estimator</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">()</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">estimator</span><span class="p">,</span> <span class="n">JavaEstimator</span><span class="p">):</span>
<span class="n">java_epms</span> <span class="o">=</span> <span class="n">gateway</span><span class="o">.</span><span class="n">new_array</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()))</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">epm</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()):</span>
<span class="n">java_epms</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">estimator</span><span class="o">.</span><span class="n">_transfer_param_map_to_java</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span>
<span class="k">elif</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">isMetaEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">):</span>
<span class="c1"># Meta estimator such as Pipeline, OneVsRest</span>
<span class="n">java_epms</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">meta_estimator_transfer_param_maps_to_java</span><span class="p">(</span>
<span class="n">estimator</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported estimator used in tuning: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">estimator</span><span class="p">))</span>
<span class="n">java_estimator</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaEstimator</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span>
<span class="n">java_evaluator</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaEvaluator</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span>
<span class="k">return</span> <span class="n">java_estimator</span><span class="p">,</span> <span class="n">java_epms</span><span class="p">,</span> <span class="n">java_evaluator</span>
<span class="k">class</span> <span class="nc">_ValidatorSharedReadWrite</span><span class="p">:</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">meta_estimator_transfer_param_maps_to_java</span><span class="p">(</span>
<span class="n">pyEstimator</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">pyParamMaps</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;JavaArray&quot;</span><span class="p">:</span>
<span class="n">pyStages</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getAllNestedStages</span><span class="p">(</span><span class="n">pyEstimator</span><span class="p">)</span>
<span class="n">stagePairs</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">stage</span><span class="p">:</span> <span class="p">(</span><span class="n">stage</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">stage</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()),</span> <span class="n">pyStages</span><span class="p">))</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="p">(</span>
<span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="p">)</span>
<span class="n">paramMapCls</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">param</span><span class="o">.</span><span class="n">ParamMap</span>
<span class="n">javaParamMaps</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_gateway</span><span class="o">.</span><span class="n">new_array</span><span class="p">(</span><span class="n">paramMapCls</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">pyParamMaps</span><span class="p">))</span>
<span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">pyParamMap</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">pyParamMaps</span><span class="p">):</span>
<span class="n">javaParamMap</span> <span class="o">=</span> <span class="n">JavaWrapper</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">&quot;org.apache.spark.ml.param.ParamMap&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">pyParam</span><span class="p">,</span> <span class="n">pyValue</span> <span class="ow">in</span> <span class="n">pyParamMap</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">javaParam</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">pyStage</span><span class="p">,</span> <span class="n">javaStage</span> <span class="ow">in</span> <span class="n">stagePairs</span><span class="p">:</span>
<span class="k">if</span> <span class="n">pyStage</span><span class="o">.</span><span class="n">_testOwnParam</span><span class="p">(</span><span class="n">pyParam</span><span class="o">.</span><span class="n">parent</span><span class="p">,</span> <span class="n">pyParam</span><span class="o">.</span><span class="n">name</span><span class="p">):</span>
<span class="n">javaParam</span> <span class="o">=</span> <span class="n">javaStage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">pyParam</span><span class="o">.</span><span class="n">name</span><span class="p">)</span>
<span class="k">break</span>
<span class="k">if</span> <span class="n">javaParam</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Resolve param in estimatorParamMaps failed: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">pyParam</span><span class="p">))</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">pyValue</span><span class="p">,</span> <span class="n">Params</span><span class="p">)</span> <span class="ow">and</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">pyValue</span><span class="p">,</span> <span class="s2">&quot;_to_java&quot;</span><span class="p">):</span>
<span class="n">javaValue</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">pyValue</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">javaValue</span> <span class="o">=</span> <span class="n">_py2java</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">pyValue</span><span class="p">)</span>
<span class="n">pair</span> <span class="o">=</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">w</span><span class="p">(</span><span class="n">javaValue</span><span class="p">)</span>
<span class="n">javaParamMap</span><span class="o">.</span><span class="n">put</span><span class="p">([</span><span class="n">pair</span><span class="p">])</span>
<span class="n">javaParamMaps</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">javaParamMap</span>
<span class="k">return</span> <span class="n">javaParamMaps</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">meta_estimator_transfer_param_maps_from_java</span><span class="p">(</span>
<span class="n">pyEstimator</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">javaParamMaps</span><span class="p">:</span> <span class="s2">&quot;JavaArray&quot;</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]:</span>
<span class="n">pyStages</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getAllNestedStages</span><span class="p">(</span><span class="n">pyEstimator</span><span class="p">)</span>
<span class="n">stagePairs</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">stage</span><span class="p">:</span> <span class="p">(</span><span class="n">stage</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">stage</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()),</span> <span class="n">pyStages</span><span class="p">))</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">pyParamMaps</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">javaParamMap</span> <span class="ow">in</span> <span class="n">javaParamMaps</span><span class="p">:</span>
<span class="n">pyParamMap</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="k">for</span> <span class="n">javaPair</span> <span class="ow">in</span> <span class="n">javaParamMap</span><span class="o">.</span><span class="n">toList</span><span class="p">():</span>
<span class="n">javaParam</span> <span class="o">=</span> <span class="n">javaPair</span><span class="o">.</span><span class="n">param</span><span class="p">()</span>
<span class="n">pyParam</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">for</span> <span class="n">pyStage</span><span class="p">,</span> <span class="n">javaStage</span> <span class="ow">in</span> <span class="n">stagePairs</span><span class="p">:</span>
<span class="k">if</span> <span class="n">pyStage</span><span class="o">.</span><span class="n">_testOwnParam</span><span class="p">(</span><span class="n">javaParam</span><span class="o">.</span><span class="n">parent</span><span class="p">(),</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">name</span><span class="p">()):</span>
<span class="n">pyParam</span> <span class="o">=</span> <span class="n">pyStage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">javaParam</span><span class="o">.</span><span class="n">name</span><span class="p">())</span>
<span class="k">if</span> <span class="n">pyParam</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Resolve param in estimatorParamMaps failed: &quot;</span>
<span class="o">+</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">parent</span><span class="p">()</span>
<span class="o">+</span> <span class="s2">&quot;.&quot;</span>
<span class="o">+</span> <span class="n">javaParam</span><span class="o">.</span><span class="n">name</span><span class="p">()</span>
<span class="p">)</span>
<span class="n">javaValue</span> <span class="o">=</span> <span class="n">javaPair</span><span class="o">.</span><span class="n">value</span><span class="p">()</span>
<span class="n">pyValue</span><span class="p">:</span> <span class="n">Any</span>
<span class="k">if</span> <span class="n">sc</span><span class="o">.</span><span class="n">_jvm</span><span class="o">.</span><span class="n">Class</span><span class="o">.</span><span class="n">forName</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.util.DefaultParamsWritable&quot;</span>
<span class="p">)</span><span class="o">.</span><span class="n">isInstance</span><span class="p">(</span><span class="n">javaValue</span><span class="p">):</span>
<span class="n">pyValue</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">javaValue</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">pyValue</span> <span class="o">=</span> <span class="n">_java2py</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">javaValue</span><span class="p">)</span>
<span class="n">pyParamMap</span><span class="p">[</span><span class="n">pyParam</span><span class="p">]</span> <span class="o">=</span> <span class="n">pyValue</span>
<span class="n">pyParamMaps</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">pyParamMap</span><span class="p">)</span>
<span class="k">return</span> <span class="n">pyParamMaps</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">is_java_convertible</span><span class="p">(</span><span class="n">instance</span><span class="p">:</span> <span class="n">_ValidatorParams</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="n">allNestedStages</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getAllNestedStages</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span>
<span class="n">evaluator_convertible</span> <span class="o">=</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">(),</span> <span class="n">JavaParams</span><span class="p">)</span>
<span class="n">estimator_convertible</span> <span class="o">=</span> <span class="nb">all</span><span class="p">(</span><span class="nb">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">stage</span><span class="p">:</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">stage</span><span class="p">,</span> <span class="s2">&quot;_to_java&quot;</span><span class="p">),</span> <span class="n">allNestedStages</span><span class="p">))</span>
<span class="k">return</span> <span class="n">estimator_convertible</span> <span class="ow">and</span> <span class="n">evaluator_convertible</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span>
<span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">instance</span><span class="p">:</span> <span class="n">_ValidatorParams</span><span class="p">,</span>
<span class="n">sc</span><span class="p">:</span> <span class="n">SparkContext</span><span class="p">,</span>
<span class="n">extraMetadata</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">numParamsNotJson</span> <span class="o">=</span> <span class="mi">0</span>
<span class="n">jsonEstimatorParamMaps</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">paramMap</span> <span class="ow">in</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">():</span>
<span class="n">jsonParamMap</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">p</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">paramMap</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">jsonParam</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;parent&quot;</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">parent</span><span class="p">,</span> <span class="s2">&quot;name&quot;</span><span class="p">:</span> <span class="n">p</span><span class="o">.</span><span class="n">name</span><span class="p">}</span>
<span class="k">if</span> <span class="p">(</span>
<span class="p">(</span><span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Estimator</span><span class="p">)</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">isMetaEstimator</span><span class="p">(</span><span class="n">v</span><span class="p">))</span>
<span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Transformer</span><span class="p">)</span>
<span class="ow">or</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">Evaluator</span><span class="p">)</span>
<span class="p">):</span>
<span class="n">relative_path</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">&quot;epm_</span><span class="si">{</span><span class="n">p</span><span class="o">.</span><span class="n">name</span><span class="si">}{</span><span class="n">numParamsNotJson</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="n">param_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">relative_path</span><span class="p">)</span>
<span class="n">numParamsNotJson</span> <span class="o">+=</span> <span class="mi">1</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">param_path</span><span class="p">)</span>
<span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">relative_path</span>
<span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;isJson&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">False</span>
<span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">MLWritable</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;ValidatorSharedReadWrite.saveImpl does not handle parameters of type: &quot;</span>
<span class="s2">&quot;MLWritable that are not Estimator/Evaluator/Transformer, and if parameter &quot;</span>
<span class="s2">&quot;is estimator, it cannot be meta estimator such as Validator or OneVsRest&quot;</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;value&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span>
<span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;isJson&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span>
<span class="n">jsonParamMap</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">jsonParam</span><span class="p">)</span>
<span class="n">jsonEstimatorParamMaps</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">jsonParamMap</span><span class="p">)</span>
<span class="n">skipParams</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;estimator&quot;</span><span class="p">,</span> <span class="s2">&quot;evaluator&quot;</span><span class="p">,</span> <span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">]</span>
<span class="n">jsonParams</span> <span class="o">=</span> <span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">extractJsonParams</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">skipParams</span><span class="p">)</span>
<span class="n">jsonParams</span><span class="p">[</span><span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">jsonEstimatorParamMaps</span>
<span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">saveMetadata</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">sc</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="p">,</span> <span class="n">jsonParams</span><span class="p">)</span>
<span class="n">evaluatorPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;evaluator&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">())</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">evaluatorPath</span><span class="p">)</span>
<span class="n">estimatorPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;estimator&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">())</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">estimatorPath</span><span class="p">)</span>
<span class="nd">@staticmethod</span><!-- Review note: static helper that reloads a saved validator. It loads the evaluator from path/evaluator and the estimator from path/estimator, rebuilds each ParamMap listed in metadata["paramMap"]["estimatorParamMaps"], and returns (metadata, estimator, evaluator, estimatorParamMaps). -->
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span>
<span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sc</span><span class="p">:</span> <span class="n">SparkContext</span><span class="p">,</span> <span class="n">metadata</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">],</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">Evaluator</span><span class="p">,</span> <span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]]:</span>
<span class="n">evaluatorPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;evaluator&quot;</span><span class="p">)</span>
<span class="n">evaluator</span><span class="p">:</span> <span class="n">Evaluator</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">evaluatorPath</span><span class="p">,</span> <span class="n">sc</span><span class="p">)</span>
<span class="n">estimatorPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;estimator&quot;</span><span class="p">)</span>
<span class="n">estimator</span><span class="p">:</span> <span class="n">Estimator</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">estimatorPath</span><span class="p">,</span> <span class="n">sc</span><span class="p">)</span>
<span class="n">uidToParams</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getUidMap</span><span class="p">(</span><span class="n">estimator</span><span class="p">)</span>
<span class="n">uidToParams</span><span class="p">[</span><span class="n">evaluator</span><span class="o">.</span><span class="n">uid</span><span class="p">]</span> <span class="o">=</span> <span class="n">evaluator</span><!-- uidToParams maps each uid from getUidMap(estimator), plus the evaluator, to its Params object so every saved param can be resolved through its "parent" uid. -->
<span class="n">jsonEstimatorParamMaps</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;paramMap&quot;</span><span class="p">][</span><span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">]</span>
<span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">jsonParamMap</span> <span class="ow">in</span> <span class="n">jsonEstimatorParamMaps</span><span class="p">:</span>
<span class="n">paramMap</span> <span class="o">=</span> <span class="p">{}</span>
<span class="k">for</span> <span class="n">jsonParam</span> <span class="ow">in</span> <span class="n">jsonParamMap</span><span class="p">:</span>
<span class="n">est</span> <span class="o">=</span> <span class="n">uidToParams</span><span class="p">[</span><span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;parent&quot;</span><span class="p">]]</span>
<span class="n">param</span> <span class="o">=</span> <span class="nb">getattr</span><span class="p">(</span><span class="n">est</span><span class="p">,</span> <span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;name&quot;</span><span class="p">])</span>
<span class="k">if</span> <span class="s2">&quot;isJson&quot;</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">jsonParam</span> <span class="ow">or</span> <span class="p">(</span><span class="s2">&quot;isJson&quot;</span> <span class="ow">in</span> <span class="n">jsonParam</span> <span class="ow">and</span> <span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;isJson&quot;</span><span class="p">]):</span><!-- The value is stored inline unless "isJson" is present and falsy. NOTE(review): the inner "isJson" in jsonParam test is redundant; upstream could write jsonParam.get("isJson", True). -->
<span class="n">value</span> <span class="o">=</span> <span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;value&quot;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">relativePath</span> <span class="o">=</span> <span class="n">jsonParam</span><span class="p">[</span><span class="s2">&quot;value&quot;</span><span class="p">]</span><!-- Non JSON values are stored as a path relative to the validator directory and reloaded with loadParamsInstance. -->
<span class="n">valueSavedPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">relativePath</span><span class="p">)</span>
<span class="n">value</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">valueSavedPath</span><span class="p">,</span> <span class="n">sc</span><span class="p">)</span>
<span class="n">paramMap</span><span class="p">[</span><span class="n">param</span><span class="p">]</span> <span class="o">=</span> <span class="n">value</span>
<span class="n">estimatorParamMaps</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">paramMap</span><span class="p">)</span>
<span class="k">return</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span>
<span class="nd">@staticmethod</span><!-- Review note: raises ValueError before saving if the evaluator or any Params object returned by getUidMap(estimator) is not MLWritable, or if a param in estimatorParamMaps has a parent uid that is not in that map. -->
<span class="k">def</span> <span class="nf">validateParams</span><span class="p">(</span><span class="n">instance</span><span class="p">:</span> <span class="n">_ValidatorParams</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">estiamtor</span> <span class="o">=</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">()</span><!-- NOTE(review): "estiamtor" is a typo for "estimator" in the upstream source; rename it there, not in this generated page. -->
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">()</span>
<span class="n">uidMap</span> <span class="o">=</span> <span class="n">MetaAlgorithmReadWrite</span><span class="o">.</span><span class="n">getUidMap</span><span class="p">(</span><span class="n">estiamtor</span><span class="p">)</span>
<span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="p">[</span><span class="n">evaluator</span><span class="p">]</span> <span class="o">+</span> <span class="nb">list</span><span class="p">(</span><span class="n">uidMap</span><span class="o">.</span><span class="n">values</span><span class="p">()):</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">MLWritable</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;Validator write will fail because it contains </span><span class="si">{</span><span class="n">elem</span><span class="o">.</span><span class="n">uid</span><span class="si">}</span><span class="s2"> &quot;</span>
<span class="sa">f</span><span class="s2">&quot;which is not writable.&quot;</span>
<span class="p">)</span>
<span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="n">instance</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">()</span>
<span class="n">paramErr</span> <span class="o">=</span> <span class="p">(</span>
<span class="s2">&quot;Validator save requires all Params in estimatorParamMaps to apply to &quot;</span>
<span class="s2">&quot;its Estimator, An extraneous Param was found: &quot;</span><!-- NOTE(review): "An" after the comma should be lowercase in the upstream error message. -->
<span class="p">)</span>
<span class="k">for</span> <span class="n">paramMap</span> <span class="ow">in</span> <span class="n">estimatorParamMaps</span><span class="p">:</span>
<span class="k">for</span> <span class="n">param</span> <span class="ow">in</span> <span class="n">paramMap</span><span class="p">:</span>
<span class="k">if</span> <span class="n">param</span><span class="o">.</span><span class="n">parent</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">uidMap</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="n">paramErr</span> <span class="o">+</span> <span class="nb">repr</span><span class="p">(</span><span class="n">param</span><span class="p">))</span>
<span class="nd">@staticmethod</span><!-- Review note: decides whether a validator model writer persists sub-models. An explicit "persistsubmodels" option is lower-cased and compared with "true"/"false"; any other value raises ValueError. Without the option, the result is whether the writer's instance has subModels. -->
<span class="k">def</span> <span class="nf">getValidatorModelWriterPersistSubModelsParam</span><span class="p">(</span><span class="n">writer</span><span class="p">:</span> <span class="n">MLWriter</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="s2">&quot;persistsubmodels&quot;</span> <span class="ow">in</span> <span class="n">writer</span><span class="o">.</span><span class="n">optionMap</span><span class="p">:</span>
<span class="n">persistSubModelsParam</span> <span class="o">=</span> <span class="n">writer</span><span class="o">.</span><span class="n">optionMap</span><span class="p">[</span><span class="s2">&quot;persistsubmodels&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span>
<span class="k">if</span> <span class="n">persistSubModelsParam</span> <span class="o">==</span> <span class="s2">&quot;true&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="k">elif</span> <span class="n">persistSubModelsParam</span> <span class="o">==</span> <span class="s2">&quot;false&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="kc">False</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;persistSubModels option value </span><span class="si">{</span><span class="n">persistSubModelsParam</span><span class="si">}</span><span class="s2"> is invalid, &quot;</span>
<span class="sa">f</span><span class="s2">&quot;the possible values are True, &#39;True&#39; or False, &#39;False&#39;&quot;</span><!-- NOTE(review): the check above is case-insensitive, so the message understates the accepted spellings; also the second f-string has no placeholders. -->
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">writer</span><span class="o">.</span><span class="n">instance</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="n">_save_with_persist_submodels_no_submodels_found_err</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="p">(</span><!-- Review note: error text raised by CrossValidatorModelWriter.saveImpl when persistSubModels is true but the model's subModels is None. -->
<span class="s2">&quot;When persisting tuning models, you can only set persistSubModels to true if the tuning &quot;</span>
<span class="s2">&quot;was done with collectSubModels set to true. To save the sub-models, try rerunning fitting &quot;</span>
<span class="s2">&quot;with collectSubModels set to true.&quot;</span>
<span class="p">)</span>
<span class="nd">@inherit_doc</span><!-- Review note: reader for CrossValidator. Metadata that is not a Python params instance is handed to JavaMLReader; otherwise the validator is rebuilt from _ValidatorSharedReadWrite.load. -->
<span class="k">class</span> <span class="nc">CrossValidatorReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="s2">&quot;CrossValidator&quot;</span><span class="p">]):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="s2">&quot;CrossValidator&quot;</span><span class="p">]):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">metadata</span><span class="p">,</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">load</span><span class="p">(</span>
<span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">metadata</span>
<span class="p">)</span>
<span class="n">cv</span> <span class="o">=</span> <span class="n">CrossValidator</span><span class="p">(</span>
<span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">estimatorParamMaps</span><span class="p">,</span> <span class="n">evaluator</span><span class="o">=</span><span class="n">evaluator</span>
<span class="p">)</span>
<span class="n">cv</span> <span class="o">=</span> <span class="n">cv</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;uid&quot;</span><span class="p">])</span><!-- Restore the uid recorded at save time. -->
<span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span><span class="n">cv</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">])</span><!-- estimatorParamMaps is skipped because it was already rebuilt by _ValidatorSharedReadWrite.load and passed to the constructor. -->
<span class="k">return</span> <span class="n">cv</span>
<span class="nd">@inherit_doc</span><!-- Review note: writer for CrossValidator; saveImpl checks writability with validateParams before delegating to _ValidatorSharedReadWrite.saveImpl. -->
<span class="k">class</span> <span class="nc">CrossValidatorWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="nd">@inherit_doc</span><!-- Review note: reader for CrossValidatorModel. Falls back to JavaMLReader for non Python metadata; otherwise reloads the shared validator state, bestModel, avgMetrics/stdMetrics and, when persistSubModels was saved as true, the per-fold sub-models. -->
<span class="k">class</span> <span class="nc">CrossValidatorModelReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">]):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">]):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModelReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">:</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">metadata</span><span class="p">,</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">load</span><span class="p">(</span>
<span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">metadata</span>
<span class="p">)</span>
<span class="n">numFolds</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;paramMap&quot;</span><span class="p">][</span><span class="s2">&quot;numFolds&quot;</span><span class="p">]</span>
<span class="n">bestModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;bestModel&quot;</span><span class="p">)</span>
<span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">bestModelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="n">avgMetrics</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;avgMetrics&quot;</span><span class="p">]</span>
<span class="k">if</span> <span class="s2">&quot;stdMetrics&quot;</span> <span class="ow">in</span> <span class="n">metadata</span><span class="p">:</span><!-- stdMetrics is optional in the metadata; when absent it is loaded as None. -->
<span class="n">stdMetrics</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;stdMetrics&quot;</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">stdMetrics</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">persistSubModels</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot;persistSubModels&quot;</span> <span class="ow">in</span> <span class="n">metadata</span><span class="p">)</span> <span class="ow">and</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;persistSubModels&quot;</span><span class="p">]</span><!-- A missing "persistSubModels" key is treated as false. -->
<span class="k">if</span> <span class="n">persistSubModels</span><span class="p">:</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="p">[[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="p">)]</span> <span class="o">*</span> <span class="n">numFolds</span><!-- NOTE(review): BUG in the upstream source. Multiplying the outer list repeats one inner list object numFolds times, so every fold row aliases the same list. After the loop below, each row holds the models loaded for the last fold. Upstream fix: [[None] * len(estimatorParamMaps) for _ in range(numFolds)]. Fix it in pyspark/ml/tuning.py, not in this generated page. -->
<span class="k">for</span> <span class="n">splitIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numFolds</span><span class="p">):</span>
<span class="k">for</span> <span class="n">paramIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="p">)):</span>
<span class="n">modelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span>
<span class="n">path</span><span class="p">,</span> <span class="s2">&quot;subModels&quot;</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;fold</span><span class="si">{</span><span class="n">splitIndex</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">paramIndex</span><span class="si">}</span><span class="s2">&quot;</span>
<span class="p">)</span>
<span class="n">subModels</span><span class="p">[</span><span class="n">splitIndex</span><span class="p">][</span><span class="n">paramIndex</span><span class="p">]</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span>
<span class="n">modelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">cvModel</span> <span class="o">=</span> <span class="n">CrossValidatorModel</span><span class="p">(</span>
<span class="n">bestModel</span><span class="p">,</span>
<span class="n">avgMetrics</span><span class="o">=</span><span class="n">avgMetrics</span><span class="p">,</span>
<span class="n">subModels</span><span class="o">=</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]],</span> <span class="n">subModels</span><span class="p">),</span>
<span class="n">stdMetrics</span><span class="o">=</span><span class="n">stdMetrics</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">cvModel</span> <span class="o">=</span> <span class="n">cvModel</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;uid&quot;</span><span class="p">])</span>
<span class="n">cvModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">cvModel</span><span class="o">.</span><span class="n">estimator</span><span class="p">,</span> <span class="n">estimator</span><span class="p">)</span><!-- The constructor call above only passes bestModel, subModels and metrics, so estimator, estimatorParamMaps and evaluator are set explicitly here. -->
<span class="n">cvModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">cvModel</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">,</span> <span class="n">estimatorParamMaps</span><span class="p">)</span>
<span class="n">cvModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">cvModel</span><span class="o">.</span><span class="n">evaluator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">)</span>
<span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span>
<span class="n">cvModel</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">]</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">cvModel</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">CrossValidatorModelWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModelWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span>
<span class="n">instance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span>
<span class="n">persistSubModels</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">getValidatorModelWriterPersistSubModelsParam</span><span class="p">(</span>
<span class="bp">self</span>
<span class="p">)</span>
<span class="n">extraMetadata</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;avgMetrics&quot;</span><span class="p">:</span> <span class="n">instance</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">,</span> <span class="s2">&quot;persistSubModels&quot;</span><span class="p">:</span> <span class="n">persistSubModels</span><span class="p">}</span>
<span class="k">if</span> <span class="n">instance</span><span class="o">.</span><span class="n">stdMetrics</span><span class="p">:</span>
<span class="n">extraMetadata</span><span class="p">[</span><span class="s2">&quot;stdMetrics&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">instance</span><span class="o">.</span><span class="n">stdMetrics</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span><span class="p">)</span>
<span class="n">bestModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;bestModel&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">bestModelPath</span><span class="p">)</span>
<span class="k">if</span> <span class="n">persistSubModels</span><span class="p">:</span>
<span class="k">if</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="n">_save_with_persist_submodels_no_submodels_found_err</span><span class="p">)</span>
<span class="n">subModelsPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;subModels&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">splitIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">()):</span>
<span class="n">splitPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">subModelsPath</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;fold</span><span class="si">{</span><span class="n">splitIndex</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">paramIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">())):</span>
<span class="n">modelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">splitPath</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">paramIndex</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span><span class="p">[</span><span class="n">splitIndex</span><span class="p">][</span><span class="n">paramIndex</span><span class="p">])</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">modelPath</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">_CrossValidatorParams</span><span class="p">(</span><span class="n">_ValidatorParams</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`CrossValidator` and :py:class:`CrossValidatorModel`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">numFolds</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;numFolds&quot;</span><span class="p">,</span>
<span class="s2">&quot;number of folds for cross validation&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toInt</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">foldCol</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;foldCol&quot;</span><span class="p">,</span>
<span class="s2">&quot;Param for the column name of user &quot;</span>
<span class="o">+</span> <span class="s2">&quot;specified fold number. Once this is specified, :py:class:`CrossValidator` &quot;</span>
<span class="o">+</span> <span class="s2">&quot;won&#39;t do random k-fold split. Note that this column should be integer type &quot;</span>
<span class="o">+</span> <span class="s2">&quot;with range [0, numFolds) and Spark will throw exception on out-of-range &quot;</span>
<span class="o">+</span> <span class="s2">&quot;fold numbers.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_CrossValidatorParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">numFolds</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">foldCol</span><span class="o">=</span><span class="s2">&quot;&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getNumFolds</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of numFolds or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">numFolds</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getFoldCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of foldCol or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">foldCol</span><span class="p">)</span>
<div class="viewcode-block" id="CrossValidator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator">[docs]</a><span class="k">class</span> <span class="nc">CrossValidator</span><span class="p">(</span>
<span class="n">Estimator</span><span class="p">[</span><span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">],</span>
<span class="n">_CrossValidatorParams</span><span class="p">,</span>
<span class="n">HasParallelism</span><span class="p">,</span>
<span class="n">HasCollectSubModels</span><span class="p">,</span>
<span class="n">MLReadable</span><span class="p">[</span><span class="s2">&quot;CrossValidator&quot;</span><span class="p">],</span>
<span class="n">MLWritable</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> K-fold cross validation performs model selection by splitting the dataset into a set of</span>
<span class="sd"> non-overlapping randomly partitioned folds which are used as separate training and test datasets</span>
<span class="sd"> e.g., with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs,</span>
<span class="sd"> each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the</span>
<span class="sd"> test set exactly once.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.classification import LogisticRegression</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.evaluation import BinaryClassificationEvaluator</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, CrossValidatorModel</span>
<span class="sd"> &gt;&gt;&gt; import tempfile</span>
<span class="sd"> &gt;&gt;&gt; dataset = spark.createDataFrame(</span>
<span class="sd"> ... [(Vectors.dense([0.0]), 0.0),</span>
<span class="sd"> ... (Vectors.dense([0.4]), 1.0),</span>
<span class="sd"> ... (Vectors.dense([0.5]), 0.0),</span>
<span class="sd"> ... (Vectors.dense([0.6]), 1.0),</span>
<span class="sd"> ... (Vectors.dense([1.0]), 1.0)] * 10,</span>
<span class="sd"> ... [&quot;features&quot;, &quot;label&quot;])</span>
<span class="sd"> &gt;&gt;&gt; lr = LogisticRegression()</span>
<span class="sd"> &gt;&gt;&gt; grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()</span>
<span class="sd"> &gt;&gt;&gt; evaluator = BinaryClassificationEvaluator()</span>
<span class="sd"> &gt;&gt;&gt; cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,</span>
<span class="sd"> ... parallelism=2)</span>
<span class="sd"> &gt;&gt;&gt; cvModel = cv.fit(dataset)</span>
<span class="sd"> &gt;&gt;&gt; cvModel.getNumFolds()</span>
<span class="sd"> 3</span>
<span class="sd"> &gt;&gt;&gt; cvModel.avgMetrics[0]</span>
<span class="sd"> 0.5</span>
<span class="sd"> &gt;&gt;&gt; path = tempfile.mkdtemp()</span>
<span class="sd"> &gt;&gt;&gt; model_path = path + &quot;/model&quot;</span>
<span class="sd"> &gt;&gt;&gt; cvModel.write().save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; cvModelRead = CrossValidatorModel.read().load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; cvModelRead.avgMetrics</span>
<span class="sd"> [0.5, ...</span>
<span class="sd"> &gt;&gt;&gt; evaluator.evaluate(cvModel.transform(dataset))</span>
<span class="sd"> 0.8333...</span>
<span class="sd"> &gt;&gt;&gt; evaluator.evaluate(cvModelRead.transform(dataset))</span>
<span class="sd"> 0.8333...</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">numFolds</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">foldCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\</span>
<span class="sd"> seed=None, parallelism=1, collectSubModels=False, foldCol=&quot;&quot;)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">CrossValidator</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="CrossValidator.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">numFolds</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">foldCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\</span>
<span class="sd"> seed=None, parallelism=1, collectSubModels=False, foldCol=&quot;&quot;):</span>
<span class="sd"> Sets params for cross validator.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setEstimator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setEstimator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`estimator`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setEstimatorParamMaps"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setEstimatorParamMaps">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setEstimatorParamMaps</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`estimatorParamMaps`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setEvaluator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setEvaluator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Evaluator</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`evaluator`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">evaluator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setNumFolds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setNumFolds">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setNumFolds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`numFolds`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">numFolds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setFoldCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setFoldCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFoldCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`foldCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">foldCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`seed`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setParallelism"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setParallelism">[docs]</a> <span class="k">def</span> <span class="nf">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`parallelism`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.setCollectSubModels"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.setCollectSubModels">[docs]</a> <span class="k">def</span> <span class="nf">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`collectSubModels`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">collectSubModels</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">_gen_avg_and_std_metrics</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="n">Tuple</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]:</span>
<span class="n">avg_metrics</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="n">std_metrics</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="k">return</span> <span class="nb">list</span><span class="p">(</span><span class="n">avg_metrics</span><span class="p">),</span> <span class="nb">list</span><span class="p">(</span><span class="n">std_metrics</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">:</span>
<span class="n">est</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">)</span>
<span class="n">epm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">)</span>
<span class="n">numModels</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span>
<span class="n">eva</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">)</span>
<span class="n">nFolds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">numFolds</span><span class="p">)</span>
<span class="n">metrics_all</span> <span class="o">=</span> <span class="p">[[</span><span class="mf">0.0</span><span class="p">]</span> <span class="o">*</span> <span class="n">numModels</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">)]</span>
<span class="n">pool</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">processes</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">(),</span> <span class="n">numModels</span><span class="p">))</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">collectSubModelsParam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span>
<span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="p">[[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numModels</span><span class="p">)]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">)]</span>
<span class="n">datasets</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_kFold</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">):</span>
<span class="n">validation</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span>
<span class="n">train</span> <span class="o">=</span> <span class="n">datasets</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span>
<span class="n">tasks</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span>
<span class="n">inheritable_thread_target</span><span class="p">,</span>
<span class="n">_parallelFitTasks</span><span class="p">(</span><span class="n">est</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">eva</span><span class="p">,</span> <span class="n">validation</span><span class="p">,</span> <span class="n">epm</span><span class="p">,</span> <span class="n">collectSubModelsParam</span><span class="p">),</span>
<span class="p">)</span>
<span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">metric</span><span class="p">,</span> <span class="n">subModel</span> <span class="ow">in</span> <span class="n">pool</span><span class="o">.</span><span class="n">imap_unordered</span><span class="p">(</span><span class="k">lambda</span> <span class="n">f</span><span class="p">:</span> <span class="n">f</span><span class="p">(),</span> <span class="n">tasks</span><span class="p">):</span>
<span class="n">metrics_all</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">metric</span>
<span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">subModels</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">subModel</span>
<span class="n">validation</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span>
<span class="n">train</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span>
<span class="n">metrics</span><span class="p">,</span> <span class="n">std_metrics</span> <span class="o">=</span> <span class="n">CrossValidator</span><span class="o">.</span><span class="n">_gen_avg_and_std_metrics</span><span class="p">(</span><span class="n">metrics_all</span><span class="p">)</span>
<span class="k">if</span> <span class="n">eva</span><span class="o">.</span><span class="n">isLargerBetter</span><span class="p">():</span>
<span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">metrics</span><span class="p">)</span>
<span class="n">bestModel</span> <span class="o">=</span> <span class="n">est</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">epm</span><span class="p">[</span><span class="n">bestIndex</span><span class="p">])</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span>
<span class="n">CrossValidatorModel</span><span class="p">(</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">metrics</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]],</span> <span class="n">subModels</span><span class="p">),</span> <span class="n">std_metrics</span><span class="p">)</span>
<span class="p">)</span>
<span class="k">def</span> <span class="nf">_kFold</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">Tuple</span><span class="p">[</span><span class="n">DataFrame</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">]]:</span>
<span class="n">nFolds</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">numFolds</span><span class="p">)</span>
<span class="n">foldCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">foldCol</span><span class="p">)</span>
<span class="n">datasets</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">foldCol</span><span class="p">:</span>
<span class="c1"># Do random k-fold split.</span>
<span class="n">seed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<span class="n">h</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="n">nFolds</span>
<span class="n">randCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> <span class="o">+</span> <span class="s2">&quot;_rand&quot;</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">&quot;*&quot;</span><span class="p">,</span> <span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="o">.</span><span class="n">alias</span><span class="p">(</span><span class="n">randCol</span><span class="p">))</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">):</span>
<span class="n">validateLB</span> <span class="o">=</span> <span class="n">i</span> <span class="o">*</span> <span class="n">h</span>
<span class="n">validateUB</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">h</span>
<span class="n">condition</span> <span class="o">=</span> <span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="n">randCol</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="n">validateLB</span><span class="p">)</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="n">randCol</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">validateUB</span><span class="p">)</span>
<span class="n">validation</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span>
<span class="n">train</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="o">~</span><span class="n">condition</span><span class="p">)</span>
<span class="n">datasets</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">train</span><span class="p">,</span> <span class="n">validation</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="c1"># Use user-specified fold numbers.</span>
<span class="k">def</span> <span class="nf">checker</span><span class="p">(</span><span class="n">foldNum</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">bool</span><span class="p">:</span>
<span class="k">if</span> <span class="n">foldNum</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">foldNum</span> <span class="o">&gt;=</span> <span class="n">nFolds</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Fold number must be in range [0, </span><span class="si">%s</span><span class="s2">), but got </span><span class="si">%s</span><span class="s2">.&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">nFolds</span><span class="p">,</span> <span class="n">foldNum</span><span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="kc">True</span>
<span class="n">checker_udf</span> <span class="o">=</span> <span class="n">UserDefinedFunction</span><span class="p">(</span><span class="n">checker</span><span class="p">,</span> <span class="n">BooleanType</span><span class="p">())</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">nFolds</span><span class="p">):</span>
<span class="n">training</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">checker_udf</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="n">foldCol</span><span class="p">])</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">col</span><span class="p">(</span><span class="n">foldCol</span><span class="p">)</span> <span class="o">!=</span> <span class="n">lit</span><span class="p">(</span><span class="n">i</span><span class="p">)))</span>
<span class="n">validation</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span>
<span class="n">checker_udf</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="n">foldCol</span><span class="p">])</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">col</span><span class="p">(</span><span class="n">foldCol</span><span class="p">)</span> <span class="o">==</span> <span class="n">lit</span><span class="p">(</span><span class="n">i</span><span class="p">))</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">training</span><span class="o">.</span><span class="n">rdd</span><span class="o">.</span><span class="n">getNumPartitions</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">training</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The training data at fold </span><span class="si">%s</span><span class="s2"> is empty.&quot;</span> <span class="o">%</span> <span class="n">i</span><span class="p">)</span>
<span class="k">if</span> <span class="n">validation</span><span class="o">.</span><span class="n">rdd</span><span class="o">.</span><span class="n">getNumPartitions</span><span class="p">()</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">or</span> <span class="nb">len</span><span class="p">(</span><span class="n">validation</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;The validation data at fold </span><span class="si">%s</span><span class="s2"> is empty.&quot;</span> <span class="o">%</span> <span class="n">i</span><span class="p">)</span>
<span class="n">datasets</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">training</span><span class="p">,</span> <span class="n">validation</span><span class="p">))</span>
<span class="k">return</span> <span class="n">datasets</span>
<div class="viewcode-block" id="CrossValidator.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Creates a copy of this instance with a randomly generated uid</span>
<span class="sd"> and some extra params. This copies creates a deep copy of</span>
<span class="sd"> the embedded paramMap, and copies the embedded and extra parameters over.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> extra : dict, optional</span>
<span class="sd"> Extra parameters to copy to the new instance</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`CrossValidator`</span>
<span class="sd"> Copy of this instance</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">newCV</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">):</span>
<span class="n">newCV</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span>
<span class="c1"># estimatorParamMaps remain the same</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">):</span>
<span class="n">newCV</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span>
<span class="k">return</span> <span class="n">newCV</span></div>
<div class="viewcode-block" id="CrossValidator.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MLWriter</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLWriter instance for this ML instance.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">return</span> <span class="n">CrossValidatorWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="CrossValidator.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidator.html#pyspark.ml.tuning.CrossValidator.read">[docs]</a> <span class="nd">@classmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">CrossValidatorReader</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLReader instance for this class.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">CrossValidatorReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidator&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given a Java CrossValidator, create and return a Python wrapper of it.</span>
<span class="sd"> Used for ML persistence.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidator</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span><span class="n">java_stage</span><span class="p">)</span>
<span class="n">numFolds</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">()</span>
<span class="n">seed</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">()</span>
<span class="n">parallelism</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">()</span>
<span class="n">collectSubModels</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span>
<span class="n">foldCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">()</span>
<span class="c1"># Create a new instance of this stage.</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span>
<span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span>
<span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">epms</span><span class="p">,</span>
<span class="n">evaluator</span><span class="o">=</span><span class="n">evaluator</span><span class="p">,</span>
<span class="n">numFolds</span><span class="o">=</span><span class="n">numFolds</span><span class="p">,</span>
<span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span>
<span class="n">parallelism</span><span class="o">=</span><span class="n">parallelism</span><span class="p">,</span>
<span class="n">collectSubModels</span><span class="o">=</span><span class="n">collectSubModels</span><span class="p">,</span>
<span class="n">foldCol</span><span class="o">=</span><span class="n">foldCol</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span>
<span class="k">return</span> <span class="n">py_stage</span>
<span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Transfer this instance to a Java CrossValidator. Used for ML persistence.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> py4j.java_gateway.JavaObject</span>
<span class="sd"> Java object equivalent to this instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidator</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span>
<span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">&quot;org.apache.spark.ml.tuning.CrossValidator&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimatorParamMaps</span><span class="p">(</span><span class="n">epms</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="n">evaluator</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setNumFolds</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setFoldCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">())</span>
<span class="k">return</span> <span class="n">_java_obj</span></div>
<div class="viewcode-block" id="CrossValidatorModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel">[docs]</a><span class="k">class</span> <span class="nc">CrossValidatorModel</span><span class="p">(</span>
<span class="n">Model</span><span class="p">,</span> <span class="n">_CrossValidatorParams</span><span class="p">,</span> <span class="n">MLReadable</span><span class="p">[</span><span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">],</span> <span class="n">MLWritable</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> CrossValidatorModel contains the model with the highest average cross-validation</span>
<span class="sd"> metric across folds and uses this model to transform input data. CrossValidatorModel</span>
<span class="sd"> also tracks the metrics for each param map evaluated.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> Since version 3.3.0, CrossValidatorModel contains a new attribute &quot;stdMetrics&quot;,</span>
<span class="sd"> which represent standard deviation of metrics for each paramMap in</span>
<span class="sd"> CrossValidator.estimatorParamMaps.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span><span class="p">,</span>
<span class="n">avgMetrics</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">subModels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">stdMetrics</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="c1">#: best model from cross validation</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span> <span class="o">=</span> <span class="n">bestModel</span>
<span class="c1">#: Average cross-validation metrics for each paramMap in</span>
<span class="c1">#: CrossValidator.estimatorParamMaps, in the corresponding order.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">avgMetrics</span> <span class="o">=</span> <span class="n">avgMetrics</span> <span class="ow">or</span> <span class="p">[]</span>
<span class="c1">#: sub model list from cross validation</span>
<span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="n">subModels</span>
<span class="c1">#: standard deviation of metrics for each paramMap in</span>
<span class="c1">#: CrossValidator.estimatorParamMaps, in the corresponding order.</span>
<span class="bp">self</span><span class="o">.</span><span class="n">stdMetrics</span> <span class="o">=</span> <span class="n">stdMetrics</span> <span class="ow">or</span> <span class="p">[]</span>
<span class="k">def</span> <span class="nf">_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
<!-- CrossValidatorModel.copy: copies bestModel with the extra params, copies the avgMetrics and stdMetrics lists, copies every sub-model without the extra params, then copies this instance's params onto the new CrossValidatorModel. NOTE(review): the listed source asserts self.subModels is not None, so copy() on a model that has no sub-models raises AssertionError. This page must mirror the upstream 3.5.3 source, so any fix belongs in pyspark/ml/tuning.py, not in this generated HTML. --><div class="viewcode-block" id="CrossValidatorModel.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Creates a copy of this instance with a randomly generated uid</span>
<span class="sd"> and some extra params. This copies the underlying bestModel,</span>
<span class="sd"> creates a deep copy of the embedded paramMap, and</span>
<span class="sd"> copies the embedded and extra parameters over.</span>
<span class="sd"> It does not copy the extra Params into the subModels.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> extra : dict, optional</span>
<span class="sd"> Extra parameters to copy to the new instance</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`CrossValidatorModel`</span>
<span class="sd"> Copy of this instance</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">bestModel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">)</span>
<span class="n">avgMetrics</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span>
<span class="p">[</span><span class="n">sub_model</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">fold_sub_models</span><span class="p">]</span>
<span class="k">for</span> <span class="n">fold_sub_models</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span>
<span class="p">]</span>
<span class="n">stdMetrics</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">stdMetrics</span><span class="p">)</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span>
<span class="n">CrossValidatorModel</span><span class="p">(</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">avgMetrics</span><span class="p">,</span> <span class="n">subModels</span><span class="p">,</span> <span class="n">stdMetrics</span><span class="p">),</span> <span class="n">extra</span><span class="o">=</span><span class="n">extra</span>
<span class="p">)</span></div>
<!-- CrossValidatorModel.write: returns a JavaMLWriter when _ValidatorSharedReadWrite.is_java_convertible(self) is true, otherwise the pure-Python CrossValidatorModelWriter. --><div class="viewcode-block" id="CrossValidatorModel.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MLWriter</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLWriter instance for this ML instance.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">return</span> <span class="n">CrossValidatorModelWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<!-- CrossValidatorModel.read: classmethod that returns a CrossValidatorModelReader constructed with cls. --><div class="viewcode-block" id="CrossValidatorModel.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.CrossValidatorModel.html#pyspark.ml.tuning.CrossValidatorModel.read">[docs]</a> <span class="nd">@classmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">CrossValidatorModelReader</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLReader instance for this class.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">CrossValidatorModelReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div>
<!-- CrossValidatorModel._from_java: converts bestModel via JavaParams._from_java and avgMetrics via _java2py. It obtains estimator, estimatorParamMaps and evaluator from _from_java_impl, sets those plus numFolds, foldCol and seed, copies the sub-models when java_stage.hasSubModels() is true, and resets the uid to the Java uid. NOTE(review): the Python model is built from bestModel and avgMetrics only, so stdMetrics is not transferred by this conversion. --><span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;CrossValidatorModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given a Java CrossValidatorModel, create and return a Python wrapper of it.</span>
<span class="sd"> Used for ML persistence.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">bestModel</span><span class="p">())</span>
<span class="n">avgMetrics</span> <span class="o">=</span> <span class="n">_java2py</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">())</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModel</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span><span class="n">java_stage</span><span class="p">)</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="n">bestModel</span><span class="o">=</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">avgMetrics</span><span class="o">=</span><span class="n">avgMetrics</span><span class="p">)</span>
<span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;evaluator&quot;</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span>
<span class="s2">&quot;estimator&quot;</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span>
<span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span>
<span class="s2">&quot;numFolds&quot;</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">(),</span>
<span class="s2">&quot;foldCol&quot;</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">(),</span>
<span class="s2">&quot;seed&quot;</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="p">{</span><span class="n">param_name</span><span class="p">:</span> <span class="n">param_val</span><span class="p">})</span>
<span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">hasSubModels</span><span class="p">():</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span>
<span class="p">[</span><span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">sub_model</span><span class="p">)</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">fold_sub_models</span><span class="p">]</span>
<span class="k">for</span> <span class="n">fold_sub_models</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">subModels</span><span class="p">()</span>
<span class="p">]</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span>
<span class="k">return</span> <span class="n">py_stage</span>
<!-- CrossValidatorModel._to_java: creates an org.apache.spark.ml.tuning.CrossValidatorModel with this uid, the Java bestModel and avgMetrics. It then sets evaluator, estimator, estimatorParamMaps (from _to_java_impl), numFolds, foldCol and seed through getParam(name).w(value), and attaches the converted sub-models when self.subModels is not None. NOTE(review): stdMetrics is not passed to the Java object here. --><span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> py4j.java_gateway.JavaObject</span>
<span class="sd"> Java object equivalent to this instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.tuning.CrossValidatorModel&quot;</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span>
<span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">(),</span>
<span class="n">_py2java</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">avgMetrics</span><span class="p">),</span>
<span class="p">)</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">CrossValidatorModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span>
<span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;evaluator&quot;</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span>
<span class="s2">&quot;estimator&quot;</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span>
<span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span>
<span class="s2">&quot;numFolds&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getNumFolds</span><span class="p">(),</span>
<span class="s2">&quot;foldCol&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getFoldCol</span><span class="p">(),</span>
<span class="s2">&quot;seed&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">java_param</span> <span class="o">=</span> <span class="n">_java_obj</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">param_name</span><span class="p">)</span>
<span class="n">pair</span> <span class="o">=</span> <span class="n">java_param</span><span class="o">.</span><span class="n">w</span><span class="p">(</span><span class="n">param_val</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">pair</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">java_sub_models</span> <span class="o">=</span> <span class="p">[</span>
<span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">sub_model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">fold_sub_models</span><span class="p">]</span>
<span class="k">for</span> <span class="n">fold_sub_models</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span>
<span class="p">]</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setSubModels</span><span class="p">(</span><span class="n">java_sub_models</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_java_obj</span></div>
<!-- TrainValidationSplitReader.load: reads metadata and delegates to JavaMLReader when the metadata is not a Python params instance. Otherwise it rebuilds a TrainValidationSplit from the estimator, evaluator and estimatorParamMaps returned by _ValidatorSharedReadWrite.load, resets its uid from metadata, and restores the remaining params while skipping estimatorParamMaps. --><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">TrainValidationSplitReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">]):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">]):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">metadata</span><span class="p">,</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">load</span><span class="p">(</span>
<span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">metadata</span>
<span class="p">)</span>
<span class="n">tvs</span> <span class="o">=</span> <span class="n">TrainValidationSplit</span><span class="p">(</span>
<span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">estimatorParamMaps</span><span class="p">,</span> <span class="n">evaluator</span><span class="o">=</span><span class="n">evaluator</span>
<span class="p">)</span>
<span class="n">tvs</span> <span class="o">=</span> <span class="n">tvs</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;uid&quot;</span><span class="p">])</span>
<span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span><span class="n">tvs</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">])</span>
<span class="k">return</span> <span class="n">tvs</span>
<!-- TrainValidationSplitWriter.saveImpl: validates the instance params, then delegates persistence to _ValidatorSharedReadWrite.saveImpl. --><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">TrainValidationSplitWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<!-- TrainValidationSplitModelReader.load: delegates to JavaMLReader for non-Python metadata. Otherwise it loads bestModel from path/bestModel and reads validationMetrics from metadata. When the persistSubModels metadata flag is true, it loads one sub-model per estimator param map from path/subModels/index. It then rebuilds the TrainValidationSplitModel, resets its uid, sets estimator, estimatorParamMaps and evaluator, and restores the remaining params while skipping estimatorParamMaps. --><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">TrainValidationSplitModelReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">]):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">]):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModelReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">:</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">metadata</span><span class="p">,</span> <span class="n">estimator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">,</span> <span class="n">estimatorParamMaps</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">load</span><span class="p">(</span>
<span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">metadata</span>
<span class="p">)</span>
<span class="n">bestModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;bestModel&quot;</span><span class="p">)</span>
<span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">bestModelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="n">validationMetrics</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;validationMetrics&quot;</span><span class="p">]</span>
<span class="n">persistSubModels</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot;persistSubModels&quot;</span> <span class="ow">in</span> <span class="n">metadata</span><span class="p">)</span> <span class="ow">and</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;persistSubModels&quot;</span><span class="p">]</span>
<span class="k">if</span> <span class="n">persistSubModels</span><span class="p">:</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="p">)</span>
<span class="k">for</span> <span class="n">paramIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="p">)):</span>
<span class="n">modelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;subModels&quot;</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">paramIndex</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">subModels</span><span class="p">[</span><span class="n">paramIndex</span><span class="p">]</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span>
<span class="n">modelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">tvsModel</span> <span class="o">=</span> <span class="n">TrainValidationSplitModel</span><span class="p">(</span>
<span class="n">bestModel</span><span class="p">,</span>
<span class="n">validationMetrics</span><span class="o">=</span><span class="n">validationMetrics</span><span class="p">,</span>
<span class="n">subModels</span><span class="o">=</span><span class="n">cast</span><span class="p">(</span><span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]],</span> <span class="n">subModels</span><span class="p">),</span>
<span class="p">)</span>
<span class="n">tvsModel</span> <span class="o">=</span> <span class="n">tvsModel</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;uid&quot;</span><span class="p">])</span>
<span class="n">tvsModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">tvsModel</span><span class="o">.</span><span class="n">estimator</span><span class="p">,</span> <span class="n">estimator</span><span class="p">)</span>
<span class="n">tvsModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">tvsModel</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">,</span> <span class="n">estimatorParamMaps</span><span class="p">)</span>
<span class="n">tvsModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">tvsModel</span><span class="o">.</span><span class="n">evaluator</span><span class="p">,</span> <span class="n">evaluator</span><span class="p">)</span>
<span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span>
<span class="n">tvsModel</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">]</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">tvsModel</span>
<!-- TrainValidationSplitModelWriter.saveImpl: validates params and writes the shared metadata with validationMetrics and the persistSubModels flag in extraMetadata. It saves bestModel to path/bestModel. When persistSubModels is true, it raises ValueError if instance.subModels is None; otherwise it saves one sub-model per estimator param map to path/subModels/index. --><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">TrainValidationSplitModelWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModelWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span>
<span class="n">instance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span>
<span class="n">persistSubModels</span> <span class="o">=</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">getValidatorModelWriterPersistSubModelsParam</span><span class="p">(</span>
<span class="bp">self</span>
<span class="p">)</span>
<span class="n">extraMetadata</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;validationMetrics&quot;</span><span class="p">:</span> <span class="n">instance</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">,</span>
<span class="s2">&quot;persistSubModels&quot;</span><span class="p">:</span> <span class="n">persistSubModels</span><span class="p">,</span>
<span class="p">}</span>
<span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span><span class="p">)</span>
<span class="n">bestModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;bestModel&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">bestModelPath</span><span class="p">)</span>
<span class="k">if</span> <span class="n">persistSubModels</span><span class="p">:</span>
<span class="k">if</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="n">_save_with_persist_submodels_no_submodels_found_err</span><span class="p">)</span>
<span class="n">subModelsPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;subModels&quot;</span><span class="p">)</span>
<span class="k">for</span> <span class="n">paramIndex</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">getEstimatorParamMaps</span><span class="p">())):</span>
<span class="n">modelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">subModelsPath</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;</span><span class="si">{</span><span class="n">paramIndex</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">subModels</span><span class="p">[</span><span class="n">paramIndex</span><span class="p">])</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">modelPath</span><span class="p">)</span>
<!-- Rendered source: shared params mixin for TrainValidationSplit and TrainValidationSplitModel; declares the trainRatio Param and registers its default of 0.75 in __init__. --><span class="k">class</span> <span class="nc">_TrainValidationSplitParams</span><span class="p">(</span><span class="n">_ValidatorParams</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`TrainValidationSplit` and :py:class:`TrainValidationSplitModel`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">trainRatio</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;trainRatio&quot;</span><span class="p">,</span>
<!-- NOTE(review): the backslash continuation sits inside the string literal, so any leading whitespace on the continued source line becomes part of the param description text. --><span class="s2">&quot;Param for ratio between train and</span><span class="se">\</span>
<span class="s2"> validation data. Must be between 0 and 1.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_TrainValidationSplitParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">trainRatio</span><span class="o">=</span><span class="mf">0.75</span><span class="p">)</span>
<!-- Getter: returns the explicitly set trainRatio, falling back to the registered default. --><span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getTrainRatio</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of trainRatio or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainRatio</span><span class="p">)</span>
<!-- Rendered source: TrainValidationSplit estimator (single random train/validation split for hyper-parameter tuning); mixes in the trainRatio params, parallelism, collectSubModels and ML read/write support. --><div class="viewcode-block" id="TrainValidationSplit"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit">[docs]</a><span class="k">class</span> <span class="nc">TrainValidationSplit</span><span class="p">(</span>
<span class="n">Estimator</span><span class="p">[</span><span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">],</span>
<span class="n">_TrainValidationSplitParams</span><span class="p">,</span>
<span class="n">HasParallelism</span><span class="p">,</span>
<span class="n">HasCollectSubModels</span><span class="p">,</span>
<span class="n">MLReadable</span><span class="p">[</span><span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">],</span>
<span class="n">MLWritable</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Validation for hyper-parameter tuning. Randomly splits the input dataset into train and</span>
<span class="sd"> validation sets, and uses evaluation metric on the validation set to select the best model.</span>
<span class="sd"> Similar to :class:`CrossValidator`, but only splits the set once.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.classification import LogisticRegression</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.evaluation import BinaryClassificationEvaluator</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.tuning import TrainValidationSplitModel</span>
<span class="sd"> &gt;&gt;&gt; import tempfile</span>
<span class="sd"> &gt;&gt;&gt; dataset = spark.createDataFrame(</span>
<span class="sd"> ... [(Vectors.dense([0.0]), 0.0),</span>
<span class="sd"> ... (Vectors.dense([0.4]), 1.0),</span>
<span class="sd"> ... (Vectors.dense([0.5]), 0.0),</span>
<span class="sd"> ... (Vectors.dense([0.6]), 1.0),</span>
<span class="sd"> ... (Vectors.dense([1.0]), 1.0)] * 10,</span>
<span class="sd"> ... [&quot;features&quot;, &quot;label&quot;]).repartition(1)</span>
<span class="sd"> &gt;&gt;&gt; lr = LogisticRegression()</span>
<span class="sd"> &gt;&gt;&gt; grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()</span>
<span class="sd"> &gt;&gt;&gt; evaluator = BinaryClassificationEvaluator()</span>
<span class="sd"> &gt;&gt;&gt; tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator,</span>
<span class="sd"> ... parallelism=1, seed=42)</span>
<span class="sd"> &gt;&gt;&gt; tvsModel = tvs.fit(dataset)</span>
<span class="sd"> &gt;&gt;&gt; tvsModel.getTrainRatio()</span>
<span class="sd"> 0.75</span>
<span class="sd"> &gt;&gt;&gt; tvsModel.validationMetrics</span>
<span class="sd"> [0.5, ...</span>
<span class="sd"> &gt;&gt;&gt; path = tempfile.mkdtemp()</span>
<span class="sd"> &gt;&gt;&gt; model_path = path + &quot;/model&quot;</span>
<span class="sd"> &gt;&gt;&gt; tvsModel.write().save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; tvsModelRead = TrainValidationSplitModel.read().load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; tvsModelRead.validationMetrics</span>
<span class="sd"> [0.5, ...</span>
<span class="sd"> &gt;&gt;&gt; evaluator.evaluate(tvsModel.transform(dataset))</span>
<span class="sd"> 0.833...</span>
<span class="sd"> &gt;&gt;&gt; evaluator.evaluate(tvsModelRead.transform(dataset))</span>
<span class="sd"> 0.833...</span>
<span class="sd"> &quot;&quot;&quot;</span>
<!-- Type annotation only: read by __init__ and setParams; presumably populated by the @keyword_only decorator with the caller's keyword arguments. --><span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<!-- Keyword-only constructor: registers a parallelism default of 1, then applies self._input_kwargs (presumably only the kwargs actually passed, captured by @keyword_only) via _set. --><span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">trainRatio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.75</span><span class="p">,</span>
<span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \</span>
<span class="sd"> trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplit</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<!-- Bulk setter mirroring the constructor signature: applies self._input_kwargs via _set and returns its result (annotated as TrainValidationSplit). --><div class="viewcode-block" id="TrainValidationSplit.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setParams">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">estimator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Estimator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">estimatorParamMaps</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">evaluator</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Evaluator</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">trainRatio</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.75</span><span class="p">,</span>
<span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">collectSubModels</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, estimator=None, estimatorParamMaps=None, evaluator=None, \</span>
<span class="sd"> trainRatio=0.75, parallelism=1, collectSubModels=False, seed=None):</span>
<span class="sd"> Sets params for the train validation split.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<!-- Setter for the estimator param; returns the result of _set (annotated as TrainValidationSplit). --><div class="viewcode-block" id="TrainValidationSplit.setEstimator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setEstimator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Estimator</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`estimator`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Setter for the estimatorParamMaps param (the candidate grid evaluated in _fit); returns the result of _set. --><div class="viewcode-block" id="TrainValidationSplit.setEstimatorParamMaps"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setEstimatorParamMaps">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setEstimatorParamMaps</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`estimatorParamMaps`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Setter for the evaluator param (its isLargerBetter decides argmax vs argmin in _fit); returns the result of _set. --><div class="viewcode-block" id="TrainValidationSplit.setEvaluator"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setEvaluator">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Evaluator</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`evaluator`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">evaluator</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Setter for trainRatio; no range validation happens here, the value is only converted by TypeConverters.toFloat on the Param. --><div class="viewcode-block" id="TrainValidationSplit.setTrainRatio"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setTrainRatio">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setTrainRatio</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`trainRatio`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">trainRatio</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Setter for seed (passed to rand() when splitting in _fit); unlike the setters above, no @since decorator is applied. --><div class="viewcode-block" id="TrainValidationSplit.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`seed`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Setter for parallelism; _fit sizes its ThreadPool as min(parallelism, number of param maps). --><div class="viewcode-block" id="TrainValidationSplit.setParallelism"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setParallelism">[docs]</a> <span class="k">def</span> <span class="nf">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`parallelism`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Setter for collectSubModels; when true, _fit keeps every per-param-map model and passes them to TrainValidationSplitModel. --><div class="viewcode-block" id="TrainValidationSplit.setCollectSubModels"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.setCollectSubModels">[docs]</a> <span class="k">def</span> <span class="nf">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`collectSubModels`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">collectSubModels</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- _fit: randomly splits dataset once, fits one model per param map on the train part (in a thread pool), scores each on the validation part, then refits the best param map on the full dataset and wraps it in a TrainValidationSplitModel. --><span class="k">def</span> <span class="nf">_fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">:</span>
<span class="n">est</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">)</span>
<span class="n">epm</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimatorParamMaps</span><span class="p">)</span>
<span class="n">numModels</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">epm</span><span class="p">)</span>
<span class="n">eva</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">)</span>
<span class="n">tRatio</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">trainRatio</span><span class="p">)</span>
<span class="n">seed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">seed</span><span class="p">)</span>
<!-- The temporary random column is named after this instance's uid, which keeps it distinct from the user's columns. --><span class="n">randCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span> <span class="o">+</span> <span class="s2">&quot;_rand&quot;</span>
<span class="n">df</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s2">&quot;*&quot;</span><span class="p">,</span> <span class="n">rand</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span><span class="o">.</span><span class="n">alias</span><span class="p">(</span><span class="n">randCol</span><span class="p">))</span>
<!-- Rows whose random value is greater than or equal to trainRatio go to validation; all remaining rows go to train. --><span class="n">condition</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">randCol</span><span class="p">]</span> <span class="o">&gt;=</span> <span class="n">tRatio</span>
<!-- Both splits are cached because every param-map task reads them; they are unpersisted after the pool loop below. --><span class="n">validation</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="n">condition</span><span class="p">)</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span>
<span class="n">train</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="o">~</span><span class="n">condition</span><span class="p">)</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="kc">None</span>
<span class="n">collectSubModelsParam</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span>
<span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numModels</span><span class="p">)]</span>
<!-- Each task is wrapped by inheritable_thread_target before being run on a pool thread. --><span class="n">tasks</span> <span class="o">=</span> <span class="nb">map</span><span class="p">(</span>
<span class="n">inheritable_thread_target</span><span class="p">,</span>
<span class="n">_parallelFitTasks</span><span class="p">(</span><span class="n">est</span><span class="p">,</span> <span class="n">train</span><span class="p">,</span> <span class="n">eva</span><span class="p">,</span> <span class="n">validation</span><span class="p">,</span> <span class="n">epm</span><span class="p">,</span> <span class="n">collectSubModelsParam</span><span class="p">),</span>
<span class="p">)</span>
<!-- NOTE(review): this ThreadPool is never closed or terminated in this method, and there is no try/finally, so train/validation stay cached if a task raises. A context manager (with ThreadPool(...) as pool) plus try/finally around the unpersist calls would address both; left unchanged here because this page renders the released source verbatim. --><span class="n">pool</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">processes</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">(),</span> <span class="n">numModels</span><span class="p">))</span>
<span class="n">metrics</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">numModels</span>
<!-- imap_unordered yields results in completion order; the returned index j puts each metric/subModel back in its param-map slot. --><span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">metric</span><span class="p">,</span> <span class="n">subModel</span> <span class="ow">in</span> <span class="n">pool</span><span class="o">.</span><span class="n">imap_unordered</span><span class="p">(</span><span class="k">lambda</span> <span class="n">f</span><span class="p">:</span> <span class="n">f</span><span class="p">(),</span> <span class="n">tasks</span><span class="p">):</span>
<span class="n">metrics</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">metric</span>
<span class="k">if</span> <span class="n">collectSubModelsParam</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">subModels</span><span class="p">[</span><span class="n">j</span><span class="p">]</span> <span class="o">=</span> <span class="n">subModel</span>
<span class="n">train</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span>
<span class="n">validation</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span>
<!-- Pick the best param map: argmax when the evaluator's metric is larger-is-better, otherwise argmin. --><span class="k">if</span> <span class="n">eva</span><span class="o">.</span><span class="n">isLargerBetter</span><span class="p">():</span>
<span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">metrics</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">bestIndex</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmin</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">metrics</span><span class="p">))</span>
<!-- The winning param map is refit on the full, unsplit input dataset. --><span class="n">bestModel</span> <span class="o">=</span> <span class="n">est</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">epm</span><span class="p">[</span><span class="n">bestIndex</span><span class="p">])</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span>
<span class="n">TrainValidationSplitModel</span><span class="p">(</span>
<span class="n">bestModel</span><span class="p">,</span>
<span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">],</span> <span class="n">metrics</span><span class="p">),</span>
<span class="n">subModels</span><span class="p">,</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="p">)</span>
<span class="p">)</span>
<div class="viewcode-block" id="TrainValidationSplit.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Creates a copy of this instance with a randomly generated uid</span>
<span class="sd"> and some extra params. This copies creates a deep copy of</span>
<span class="sd"> the embedded paramMap, and copies the embedded and extra parameters over.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> extra : dict, optional</span>
<span class="sd"> Extra parameters to copy to the new instance</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`TrainValidationSplit`</span>
<span class="sd"> Copy of this instance</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">newTVS</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">estimator</span><span class="p">):</span>
<span class="n">newTVS</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEstimator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span>
<span class="c1"># estimatorParamMaps remain the same</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">evaluator</span><span class="p">):</span>
<span class="n">newTVS</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getEvaluator</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span>
<span class="k">return</span> <span class="n">newTVS</span></div>
<div class="viewcode-block" id="TrainValidationSplit.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MLWriter</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLWriter instance for this ML instance.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">return</span> <span class="n">TrainValidationSplitWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="TrainValidationSplit.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplit.html#pyspark.ml.tuning.TrainValidationSplit.read">[docs]</a> <span class="nd">@classmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">TrainValidationSplitReader</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLReader instance for this class.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">TrainValidationSplitReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplit&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given a Java TrainValidationSplit, create and return a Python wrapper of it.</span>
<span class="sd"> Used for ML persistence.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplit</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span><span class="n">java_stage</span><span class="p">)</span>
<span class="n">trainRatio</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">()</span>
<span class="n">seed</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">()</span>
<span class="n">parallelism</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">()</span>
<span class="n">collectSubModels</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">()</span>
<span class="c1"># Create a new instance of this stage.</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span>
<span class="n">estimator</span><span class="o">=</span><span class="n">estimator</span><span class="p">,</span>
<span class="n">estimatorParamMaps</span><span class="o">=</span><span class="n">epms</span><span class="p">,</span>
<span class="n">evaluator</span><span class="o">=</span><span class="n">evaluator</span><span class="p">,</span>
<span class="n">trainRatio</span><span class="o">=</span><span class="n">trainRatio</span><span class="p">,</span>
<span class="n">seed</span><span class="o">=</span><span class="n">seed</span><span class="p">,</span>
<span class="n">parallelism</span><span class="o">=</span><span class="n">parallelism</span><span class="p">,</span>
<span class="n">collectSubModels</span><span class="o">=</span><span class="n">collectSubModels</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span>
<span class="k">return</span> <span class="n">py_stage</span>
<span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> py4j.java_gateway.JavaObject</span>
<span class="sd"> Java object equivalent to this instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplit</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span>
<span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.tuning.TrainValidationSplit&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimatorParamMaps</span><span class="p">(</span><span class="n">epms</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setEvaluator</span><span class="p">(</span><span class="n">evaluator</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setEstimator</span><span class="p">(</span><span class="n">estimator</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setTrainRatio</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setCollectSubModels</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getCollectSubModels</span><span class="p">())</span>
<span class="k">return</span> <span class="n">_java_obj</span></div>
<div class="viewcode-block" id="TrainValidationSplitModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel">[docs]</a><span class="k">class</span> <span class="nc">TrainValidationSplitModel</span><span class="p">(</span>
<span class="n">Model</span><span class="p">,</span> <span class="n">_TrainValidationSplitParams</span><span class="p">,</span> <span class="n">MLReadable</span><span class="p">[</span><span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">],</span> <span class="n">MLWritable</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model from train validation split.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span><span class="p">,</span>
<span class="n">validationMetrics</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">subModels</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="n">Model</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="c1">#: best model from train validation split</span>
<span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span> <span class="o">=</span> <span class="n">bestModel</span>
<span class="c1">#: evaluated validation metrics</span>
<span class="bp">self</span><span class="o">.</span><span class="n">validationMetrics</span> <span class="o">=</span> <span class="n">validationMetrics</span> <span class="ow">or</span> <span class="p">[]</span>
<span class="c1">#: sub models from train validation split</span>
<span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="n">subModels</span>
<span class="k">def</span> <span class="nf">_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
<div class="viewcode-block" id="TrainValidationSplitModel.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Creates a copy of this instance with a randomly generated uid</span>
<span class="sd"> and some extra params. This copies the underlying bestModel,</span>
<span class="sd"> creates a deep copy of the embedded paramMap, and</span>
<span class="sd"> copies the embedded and extra parameters over.</span>
<span class="sd"> It also creates a shallow copy of the validationMetrics.</span>
<span class="sd"> It does not copy the extra Params into the subModels.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> extra : dict, optional</span>
<span class="sd"> Extra parameters to copy to the new instance</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`TrainValidationSplitModel`</span>
<span class="sd"> Copy of this instance</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">bestModel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">)</span>
<span class="n">validationMetrics</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">)</span>
<span class="k">assert</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="n">model</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span><span class="p">]</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span>
<span class="n">TrainValidationSplitModel</span><span class="p">(</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">validationMetrics</span><span class="p">,</span> <span class="n">subModels</span><span class="p">),</span> <span class="n">extra</span><span class="o">=</span><span class="n">extra</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="TrainValidationSplitModel.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel.write">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MLWriter</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLWriter instance for this ML instance.&quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">_ValidatorSharedReadWrite</span><span class="o">.</span><span class="n">is_java_convertible</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">return</span> <span class="n">TrainValidationSplitModelWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div>
<div class="viewcode-block" id="TrainValidationSplitModel.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.tuning.TrainValidationSplitModel.html#pyspark.ml.tuning.TrainValidationSplitModel.read">[docs]</a> <span class="nd">@classmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">TrainValidationSplitModelReader</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns an MLReader instance for this class.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="n">TrainValidationSplitModelReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;TrainValidationSplitModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.</span>
<span class="sd"> Used for ML persistence.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="c1"># Load information from java_stage to the instance.</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">bestModel</span><span class="p">:</span> <span class="n">Model</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">bestModel</span><span class="p">())</span>
<span class="n">validationMetrics</span> <span class="o">=</span> <span class="n">_java2py</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">())</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModel</span><span class="p">,</span> <span class="bp">cls</span><span class="p">)</span><span class="o">.</span><span class="n">_from_java_impl</span><span class="p">(</span>
<span class="n">java_stage</span>
<span class="p">)</span>
<span class="c1"># Create a new instance of this stage.</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="n">bestModel</span><span class="o">=</span><span class="n">bestModel</span><span class="p">,</span> <span class="n">validationMetrics</span><span class="o">=</span><span class="n">validationMetrics</span><span class="p">)</span>
<span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;evaluator&quot;</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span>
<span class="s2">&quot;estimator&quot;</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span>
<span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span>
<span class="s2">&quot;trainRatio&quot;</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">(),</span>
<span class="s2">&quot;seed&quot;</span><span class="p">:</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="p">{</span><span class="n">param_name</span><span class="p">:</span> <span class="n">param_val</span><span class="p">})</span>
<span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">hasSubModels</span><span class="p">():</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">sub_model</span><span class="p">)</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">subModels</span><span class="p">()</span>
<span class="p">]</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span>
<span class="k">return</span> <span class="n">py_stage</span>
<span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> py4j.java_gateway.JavaObject</span>
<span class="sd"> Java object equivalent to this instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.tuning.TrainValidationSplitModel&quot;</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span>
<span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">bestModel</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">(),</span>
<span class="n">_py2java</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">validationMetrics</span><span class="p">),</span>
<span class="p">)</span>
<span class="n">estimator</span><span class="p">,</span> <span class="n">epms</span><span class="p">,</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="nb">super</span><span class="p">(</span><span class="n">TrainValidationSplitModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java_impl</span><span class="p">()</span>
<span class="n">params</span> <span class="o">=</span> <span class="p">{</span>
<span class="s2">&quot;evaluator&quot;</span><span class="p">:</span> <span class="n">evaluator</span><span class="p">,</span>
<span class="s2">&quot;estimator&quot;</span><span class="p">:</span> <span class="n">estimator</span><span class="p">,</span>
<span class="s2">&quot;estimatorParamMaps&quot;</span><span class="p">:</span> <span class="n">epms</span><span class="p">,</span>
<span class="s2">&quot;trainRatio&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getTrainRatio</span><span class="p">(),</span>
<span class="s2">&quot;seed&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">getSeed</span><span class="p">(),</span>
<span class="p">}</span>
<span class="k">for</span> <span class="n">param_name</span><span class="p">,</span> <span class="n">param_val</span> <span class="ow">in</span> <span class="n">params</span><span class="o">.</span><span class="n">items</span><span class="p">():</span>
<span class="n">java_param</span> <span class="o">=</span> <span class="n">_java_obj</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="n">param_name</span><span class="p">)</span>
<span class="n">pair</span> <span class="o">=</span> <span class="n">java_param</span><span class="o">.</span><span class="n">w</span><span class="p">(</span><span class="n">param_val</span><span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">pair</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">java_sub_models</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">cast</span><span class="p">(</span><span class="n">JavaParams</span><span class="p">,</span> <span class="n">sub_model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">sub_model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">subModels</span>
<span class="p">]</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setSubModels</span><span class="p">(</span><span class="n">java_sub_models</span><span class="p">)</span>
<span class="k">return</span> <span class="n">_java_obj</span></div>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">doctest</span>
<span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">SparkSession</span>
<span class="n">globs</span> <span class="o">=</span> <span class="nb">globals</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="c1"># Start a local SparkSession with two worker threads for running the doctests:</span>
<span class="n">spark</span> <span class="o">=</span> <span class="n">SparkSession</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">master</span><span class="p">(</span><span class="s2">&quot;local[2]&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">appName</span><span class="p">(</span><span class="s2">&quot;ml.tuning tests&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">getOrCreate</span><span class="p">()</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">sparkContext</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;sc&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">sc</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;spark&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">spark</span>
<span class="p">(</span><span class="n">failure_count</span><span class="p">,</span> <span class="n">test_count</span><span class="p">)</span> <span class="o">=</span> <span class="n">doctest</span><span class="o">.</span><span class="n">testmod</span><span class="p">(</span><span class="n">globs</span><span class="o">=</span><span class="n">globs</span><span class="p">,</span> <span class="n">optionflags</span><span class="o">=</span><span class="n">doctest</span><span class="o">.</span><span class="n">ELLIPSIS</span><span class="p">)</span>
<span class="n">spark</span><span class="o">.</span><span class="n">stop</span><span class="p">()</span>
<span class="k">if</span> <span class="n">failure_count</span><span class="p">:</span>
<span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</pre></div>
</div>
<!-- Previous / next buttons -->
<div class='prev-next-area'>
</div>
</main>
</div>
</div>
<script src="../../../_static/scripts/pydata-sphinx-theme.js?digest=1999514e3f237ded88cf"></script>
<footer class="footer mt-5 mt-md-0">
<div class="container">
<div class="footer-item">
<p class="copyright">
&copy; Copyright .<br>
</p>
</div>
<div class="footer-item">
<p class="sphinx-version">
Created using <a href="http://sphinx-doc.org/">Sphinx</a> 3.0.4.<br>
</p>
</div>
</div>
</footer>
</body>
</html>