blob: d81dfcb5675d851b199cf9627afeca711fa1cc7a [file] [log] [blame]
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>pyspark.ml.classification &#8212; PySpark 4.0.0-preview1 documentation</title>
<script data-cfasync="false">
// Restore the color mode and theme saved in localStorage; fall back to "" and "light" when nothing is stored.
{
  const rootData = document.documentElement.dataset;
  rootData.mode = localStorage.getItem("mode") || "";
  rootData.theme = localStorage.getItem("theme") || "light";
}
</script>
<!-- Loaded before other Sphinx assets -->
<link href="../../../_static/styles/theme.css?digest=e353d410970836974a52" rel="stylesheet" />
<link href="../../../_static/styles/bootstrap.css?digest=e353d410970836974a52" rel="stylesheet" />
<link href="../../../_static/styles/pydata-sphinx-theme.css?digest=e353d410970836974a52" rel="stylesheet" />
<link href="../../../_static/vendor/fontawesome/6.1.2/css/all.min.css?digest=e353d410970836974a52" rel="stylesheet" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/6.1.2/webfonts/fa-solid-900.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/6.1.2/webfonts/fa-brands-400.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/6.1.2/webfonts/fa-regular-400.woff2" />
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/pyspark.css" />
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../../../_static/scripts/bootstrap.js?digest=e353d410970836974a52" />
<link rel="preload" as="script" href="../../../_static/scripts/pydata-sphinx-theme.js?digest=e353d410970836974a52" />
<script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
<script src="../../../_static/jquery.js"></script>
<script src="../../../_static/underscore.js"></script>
<script src="../../../_static/doctools.js"></script>
<script src="../../../_static/clipboard.min.js"></script>
<script src="../../../_static/copybutton.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script>DOCUMENTATION_OPTIONS.pagename = '_modules/pyspark/ml/classification';</script>
<link rel="canonical" href="https://spark.apache.org/docs/latest/api/python/_modules/pyspark/ml/classification.html" />
<link rel="search" title="Search" href="../../../search.html" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="docsearch:language" content="en">
<!-- Matomo -->
<script type="text/javascript">
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
_paq.push(["disableCookies"]);
_paq.push(["trackPageView"]);
_paq.push(["enableLinkTracking"]);
(function () {
  // Queue the tracker endpoint and site id, then load matomo.js asynchronously
  // by inserting a new <script> element ahead of the first script on the page.
  var matomoBase = "https://analytics.apache.org/";
  _paq.push(["setTrackerUrl", matomoBase + "matomo.php"]);
  _paq.push(["setSiteId", "40"]);
  var loader = document.createElement("script");
  var firstScript = document.getElementsByTagName("script")[0];
  loader.async = true;
  loader.src = matomoBase + "matomo.js";
  firstScript.parentNode.insertBefore(loader, firstScript);
})();
</script>
<!-- End Matomo Code -->
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
<a class="skip-link" href="#main-content">Skip to main content</a>
<input type="checkbox"
class="sidebar-toggle"
name="__primary"
id="__primary"/>
<label class="overlay overlay-primary" for="__primary"></label>
<input type="checkbox"
class="sidebar-toggle"
name="__secondary"
id="__secondary"/>
<label class="overlay overlay-secondary" for="__secondary"></label>
<div class="search-button__wrapper">
<div class="search-button__overlay"></div>
<div class="search-button__search-container">
<form class="bd-search d-flex align-items-center"
action="../../../search.html"
method="get">
<i class="fa-solid fa-magnifying-glass"></i>
<input type="search"
class="form-control"
name="q"
id="search-input"
placeholder="Search the docs ..."
aria-label="Search the docs ..."
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"/>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
</form></div>
</div>
<nav class="bd-header navbar navbar-expand-lg bd-navbar">
<div class="bd-header__inner bd-page-width">
<label class="sidebar-toggle primary-toggle" for="__primary">
<span class="fa-solid fa-bars"></span>
</label>
<div class="navbar-header-items__start">
<div class="navbar-item">
<a class="navbar-brand logo" href="../../../index.html">
<img src="../../../_static/spark-logo-light.png" class="logo__image only-light" alt="Logo image"/>
<script>document.write(`<img src="../../../_static/spark-logo-dark.png" class="logo__image only-dark" alt="Logo image"/>`);</script>
</a></div>
</div>
<div class="col-lg-9 navbar-header-items">
<div class="me-auto navbar-header-items__center">
<div class="navbar-item"><nav class="navbar-nav">
<p class="sidebar-header-items__title"
role="heading"
aria-level="1"
aria-label="Site Navigation">
Site Navigation
</p>
<ul class="bd-navbar-elements navbar-nav">
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../index.html">
Overview
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../getting_started/index.html">
Getting Started
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../user_guide/index.html">
User Guides
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../reference/index.html">
API Reference
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../development/index.html">
Development
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../migration_guide/index.html">
Migration Guides
</a>
</li>
</ul>
</nav></div>
</div>
<div class="navbar-header-items__end">
<div class="navbar-item navbar-persistent--container">
<script>
// Emitted via document.write so the search button only exists when JavaScript runs.
document.write(`
<button class="btn btn-sm navbar-btn search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
</button>
`);
</script>
</div>
<div class="navbar-item"><!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<div id="version-button" class="dropdown">
<button type="button" class="btn btn-secondary btn-sm navbar-btn dropdown-toggle" id="version_switcher_button" data-toggle="dropdown">
4.0.0-preview1
<span class="caret"></span>
</button>
<div id="version_switcher" class="dropdown-menu list-group-flush py-0" aria-labelledby="version_switcher_button">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div>
<script type="text/javascript">
// Build the docs landing-page URL for one versions.json entry by substituting
// entry.version into the "{version}" placeholder of the URL template.
function buildURL(entry) {
  const urlTemplate = "https://spark.apache.org/docs/{version}/api/python/index.html"; // supplied by jinja
  return urlTemplate.replace("{version}", entry.version);
}
// Click handler for a version link: go to the page matching the current one in
// the other docs version when it exists there, otherwise go to that version's
// homepage. Always returns false so the link's default navigation is cancelled
// while the HEAD probe decides where to go.
function checkPageExistsAndRedirect(event) {
  const currentFilePath = "_modules/pyspark/ml/classification.html",
        otherDocsHomepage = event.target.getAttribute("href");
  // The version links point at ".../index.html". The probe URL has to be built
  // from the containing directory; appending to the href as-is produces
  // ".../index.html_modules/...", which never exists, so every click would fall
  // back to the homepage.
  const otherDocsRoot = otherDocsHomepage.replace(/index\.html$/, "");
  const tryUrl = `${otherDocsRoot}${currentFilePath}`;
  $.ajax({
    type: 'HEAD',
    url: tryUrl,
    // if the page exists, go there
    success: function() {
      location.href = tryUrl;
    }
  }).fail(function() {
    // page is missing in that version (or the probe failed): use its homepage
    location.href = otherDocsHomepage;
  });
  return false;
}
// Populate the version switcher dropdown that sits next to this script.
(function () {
  // This page contains two elements with id "version_switcher" (header and
  // sidebar), and an id lookup only returns the first one. Resolve the dropdown
  // relative to this <script> element instead. This must happen synchronously:
  // document.currentScript is null by the time the async callback below runs.
  const scriptEl = document.currentScript;
  const scope = scriptEl && scriptEl.parentNode;
  const dropdown = (scope && scope.querySelector("#version_switcher")) ||
    document.getElementById("version_switcher");
  if (!dropdown) {
    return; // nothing to populate
  }
  // get JSON config
  $.getJSON("https://spark.apache.org/static/versions.json", function(data) {
    // create the nodes in the order given by the JSON
    // (links point at each version's docs homepage; the click handler
    // redirects to the matching page when it exists)
    $.each(data, function(index, entry) {
      // if no custom name specified (e.g., "latest"), use version string
      if (!("name" in entry)) {
        entry.name = entry.version;
      }
      // construct the appropriate URL, and add it to the dropdown
      entry.url = buildURL(entry);
      const node = document.createElement("a");
      node.setAttribute("class", "list-group-item list-group-item-action py-1");
      node.setAttribute("href", `${entry.url}`);
      node.textContent = `${entry.name}`;
      node.onclick = checkPageExistsAndRedirect;
      dropdown.appendChild(node);
    });
  });
})();
</script></div>
<div class="navbar-item">
<script>
// Emitted via document.write so the theme switch button only exists when JavaScript runs.
document.write(`
<button class="theme-switch-button btn btn-sm btn-outline-primary navbar-btn rounded-circle" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="theme-switch" data-mode="light"><i class="fa-solid fa-sun"></i></span>
<span class="theme-switch" data-mode="dark"><i class="fa-solid fa-moon"></i></span>
<span class="theme-switch" data-mode="auto"><i class="fa-solid fa-circle-half-stroke"></i></span>
</button>
`);
</script></div>
<div class="navbar-item"><ul class="navbar-icon-links navbar-nav"
aria-label="Icon Links">
<li class="nav-item">
<a href="https://github.com/apache/spark" title="GitHub" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-brands fa-github"></i></span>
<label class="sr-only">GitHub</label></a>
</li>
<li class="nav-item">
<a href="https://pypi.org/project/pyspark" title="PyPI" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-solid fa-box"></i></span>
<label class="sr-only">PyPI</label></a>
</li>
</ul></div>
</div>
</div>
<div class="navbar-persistent--mobile">
<script>
// Emitted via document.write so the search button only exists when JavaScript runs.
document.write(`
<button class="btn btn-sm navbar-btn search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
</button>
`);
</script>
</div>
</div>
</nav>
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
<div class="bd-sidebar-primary bd-sidebar hide-on-wide">
<div class="sidebar-header-items sidebar-primary__section">
<div class="sidebar-header-items__center">
<div class="navbar-item"><nav class="navbar-nav">
<p class="sidebar-header-items__title"
role="heading"
aria-level="1"
aria-label="Site Navigation">
Site Navigation
</p>
<ul class="bd-navbar-elements navbar-nav">
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../index.html">
Overview
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../getting_started/index.html">
Getting Started
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../user_guide/index.html">
User Guides
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../reference/index.html">
API Reference
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../development/index.html">
Development
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../migration_guide/index.html">
Migration Guides
</a>
</li>
</ul>
</nav></div>
</div>
<div class="sidebar-header-items__end">
<div class="navbar-item"><!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<div id="version-button" class="dropdown">
<button type="button" class="btn btn-secondary btn-sm navbar-btn dropdown-toggle" id="version_switcher_button" data-toggle="dropdown">
4.0.0-preview1
<span class="caret"></span>
</button>
<div id="version_switcher" class="dropdown-menu list-group-flush py-0" aria-labelledby="version_switcher_button">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div>
<script type="text/javascript">
// Build the docs landing-page URL for one versions.json entry by substituting
// entry.version into the "{version}" placeholder of the URL template.
function buildURL(entry) {
  const urlTemplate = "https://spark.apache.org/docs/{version}/api/python/index.html"; // supplied by jinja
  return urlTemplate.replace("{version}", entry.version);
}
// Click handler for a version link: go to the page matching the current one in
// the other docs version when it exists there, otherwise go to that version's
// homepage. Always returns false so the link's default navigation is cancelled
// while the HEAD probe decides where to go.
function checkPageExistsAndRedirect(event) {
  const currentFilePath = "_modules/pyspark/ml/classification.html",
        otherDocsHomepage = event.target.getAttribute("href");
  // The version links point at ".../index.html". The probe URL has to be built
  // from the containing directory; appending to the href as-is produces
  // ".../index.html_modules/...", which never exists, so every click would fall
  // back to the homepage.
  const otherDocsRoot = otherDocsHomepage.replace(/index\.html$/, "");
  const tryUrl = `${otherDocsRoot}${currentFilePath}`;
  $.ajax({
    type: 'HEAD',
    url: tryUrl,
    // if the page exists, go there
    success: function() {
      location.href = tryUrl;
    }
  }).fail(function() {
    // page is missing in that version (or the probe failed): use its homepage
    location.href = otherDocsHomepage;
  });
  return false;
}
// Populate the version switcher dropdown that sits next to this script.
(function () {
  // This page contains two elements with id "version_switcher" (header and
  // sidebar), and an id lookup only returns the first one. Resolve the dropdown
  // relative to this <script> element instead. This must happen synchronously:
  // document.currentScript is null by the time the async callback below runs.
  const scriptEl = document.currentScript;
  const scope = scriptEl && scriptEl.parentNode;
  const dropdown = (scope && scope.querySelector("#version_switcher")) ||
    document.getElementById("version_switcher");
  if (!dropdown) {
    return; // nothing to populate
  }
  // get JSON config
  $.getJSON("https://spark.apache.org/static/versions.json", function(data) {
    // create the nodes in the order given by the JSON
    // (links point at each version's docs homepage; the click handler
    // redirects to the matching page when it exists)
    $.each(data, function(index, entry) {
      // if no custom name specified (e.g., "latest"), use version string
      if (!("name" in entry)) {
        entry.name = entry.version;
      }
      // construct the appropriate URL, and add it to the dropdown
      entry.url = buildURL(entry);
      const node = document.createElement("a");
      node.setAttribute("class", "list-group-item list-group-item-action py-1");
      node.setAttribute("href", `${entry.url}`);
      node.textContent = `${entry.name}`;
      node.onclick = checkPageExistsAndRedirect;
      dropdown.appendChild(node);
    });
  });
})();
</script></div>
<div class="navbar-item">
<script>
// Emitted via document.write so the theme switch button only exists when JavaScript runs.
document.write(`
<button class="theme-switch-button btn btn-sm btn-outline-primary navbar-btn rounded-circle" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="theme-switch" data-mode="light"><i class="fa-solid fa-sun"></i></span>
<span class="theme-switch" data-mode="dark"><i class="fa-solid fa-moon"></i></span>
<span class="theme-switch" data-mode="auto"><i class="fa-solid fa-circle-half-stroke"></i></span>
</button>
`);
</script></div>
<div class="navbar-item"><ul class="navbar-icon-links navbar-nav"
aria-label="Icon Links">
<li class="nav-item">
<a href="https://github.com/apache/spark" title="GitHub" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-brands fa-github"></i></span>
<label class="sr-only">GitHub</label></a>
</li>
<li class="nav-item">
<a href="https://pypi.org/project/pyspark" title="PyPI" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-solid fa-box"></i></span>
<label class="sr-only">PyPI</label></a>
</li>
</ul></div>
</div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
<div id="rtd-footer-container"></div>
</div>
<main id="main-content" class="bd-main">
<div class="bd-content">
<div class="bd-article-container">
<div class="bd-header-article">
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
<div class="header-article-item">
<nav aria-label="Breadcrumbs">
<ul class="bd-breadcrumbs">
<li class="breadcrumb-item breadcrumb-home">
<a href="../../../index.html" class="nav-link" aria-label="Home">
<i class="fa-solid fa-home"></i>
</a>
</li>
<li class="breadcrumb-item"><a href="../../index.html" class="nav-link">Module code</a></li>
<li class="breadcrumb-item active" aria-current="page">pyspark.ml.classification</li>
</ul>
</nav>
</div>
</div>
</div>
</div>
<div id="searchbox"></div>
<article class="bd-article" role="main">
<h1>Source code for pyspark.ml.classification</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Licensed to the Apache Software Foundation (ASF) under one or more</span>
<span class="c1"># contributor license agreements. See the NOTICE file distributed with</span>
<span class="c1"># this work for additional information regarding copyright ownership.</span>
<span class="c1"># The ASF licenses this file to You under the Apache License, Version 2.0</span>
<span class="c1"># (the &quot;License&quot;); you may not use this file except in compliance with</span>
<span class="c1"># the License. You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">operator</span>
<span class="kn">import</span> <span class="nn">sys</span>
<span class="kn">import</span> <span class="nn">uuid</span>
<span class="kn">import</span> <span class="nn">warnings</span>
<span class="kn">from</span> <span class="nn">abc</span> <span class="kn">import</span> <span class="n">ABCMeta</span><span class="p">,</span> <span class="n">abstractmethod</span>
<span class="kn">from</span> <span class="nn">multiprocessing.pool</span> <span class="kn">import</span> <span class="n">ThreadPool</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="p">(</span>
<span class="n">Any</span><span class="p">,</span>
<span class="n">Dict</span><span class="p">,</span>
<span class="n">Generic</span><span class="p">,</span>
<span class="n">Iterable</span><span class="p">,</span>
<span class="n">List</span><span class="p">,</span>
<span class="n">Optional</span><span class="p">,</span>
<span class="n">Type</span><span class="p">,</span>
<span class="n">TypeVar</span><span class="p">,</span>
<span class="n">Union</span><span class="p">,</span>
<span class="n">cast</span><span class="p">,</span>
<span class="n">overload</span><span class="p">,</span>
<span class="n">TYPE_CHECKING</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">from</span> <span class="nn">pyspark</span> <span class="kn">import</span> <span class="n">keyword_only</span><span class="p">,</span> <span class="n">since</span><span class="p">,</span> <span class="n">inheritable_thread_target</span>
<span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Estimator</span><span class="p">,</span> <span class="n">Predictor</span><span class="p">,</span> <span class="n">PredictionModel</span><span class="p">,</span> <span class="n">Model</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.param.shared</span> <span class="kn">import</span> <span class="p">(</span>
<span class="n">HasRawPredictionCol</span><span class="p">,</span>
<span class="n">HasProbabilityCol</span><span class="p">,</span>
<span class="n">HasThresholds</span><span class="p">,</span>
<span class="n">HasRegParam</span><span class="p">,</span>
<span class="n">HasMaxIter</span><span class="p">,</span>
<span class="n">HasFitIntercept</span><span class="p">,</span>
<span class="n">HasTol</span><span class="p">,</span>
<span class="n">HasStandardization</span><span class="p">,</span>
<span class="n">HasWeightCol</span><span class="p">,</span>
<span class="n">HasAggregationDepth</span><span class="p">,</span>
<span class="n">HasThreshold</span><span class="p">,</span>
<span class="n">HasBlockSize</span><span class="p">,</span>
<span class="n">HasMaxBlockSizeInMB</span><span class="p">,</span>
<span class="n">Param</span><span class="p">,</span>
<span class="n">Params</span><span class="p">,</span>
<span class="n">TypeConverters</span><span class="p">,</span>
<span class="n">HasElasticNetParam</span><span class="p">,</span>
<span class="n">HasSeed</span><span class="p">,</span>
<span class="n">HasStepSize</span><span class="p">,</span>
<span class="n">HasSolver</span><span class="p">,</span>
<span class="n">HasParallelism</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.tree</span> <span class="kn">import</span> <span class="p">(</span>
<span class="n">_DecisionTreeModel</span><span class="p">,</span>
<span class="n">_DecisionTreeParams</span><span class="p">,</span>
<span class="n">_TreeEnsembleModel</span><span class="p">,</span>
<span class="n">_RandomForestParams</span><span class="p">,</span>
<span class="n">_GBTParams</span><span class="p">,</span>
<span class="n">_HasVarianceImpurity</span><span class="p">,</span>
<span class="n">_TreeClassifierParams</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">_FactorizationMachinesParams</span><span class="p">,</span> <span class="n">DecisionTreeRegressionModel</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.base</span> <span class="kn">import</span> <span class="n">_PredictorParams</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.util</span> <span class="kn">import</span> <span class="p">(</span>
<span class="n">DefaultParamsReader</span><span class="p">,</span>
<span class="n">DefaultParamsWriter</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">,</span>
<span class="n">JavaMLReader</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLWriter</span><span class="p">,</span>
<span class="n">MLReader</span><span class="p">,</span>
<span class="n">MLReadable</span><span class="p">,</span>
<span class="n">MLWriter</span><span class="p">,</span>
<span class="n">MLWritable</span><span class="p">,</span>
<span class="n">HasTrainingSummary</span><span class="p">,</span>
<span class="p">)</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.wrapper</span> <span class="kn">import</span> <span class="n">JavaParams</span><span class="p">,</span> <span class="n">JavaPredictor</span><span class="p">,</span> <span class="n">JavaPredictionModel</span><span class="p">,</span> <span class="n">JavaWrapper</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.common</span> <span class="kn">import</span> <span class="n">inherit_doc</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.linalg</span> <span class="kn">import</span> <span class="n">Matrix</span><span class="p">,</span> <span class="n">Vector</span><span class="p">,</span> <span class="n">Vectors</span><span class="p">,</span> <span class="n">VectorUDT</span>
<span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">DataFrame</span><span class="p">,</span> <span class="n">Row</span>
<span class="kn">from</span> <span class="nn">pyspark.sql.functions</span> <span class="kn">import</span> <span class="n">udf</span><span class="p">,</span> <span class="n">when</span>
<span class="kn">from</span> <span class="nn">pyspark.sql.types</span> <span class="kn">import</span> <span class="n">ArrayType</span><span class="p">,</span> <span class="n">DoubleType</span>
<span class="kn">from</span> <span class="nn">pyspark.storagelevel</span> <span class="kn">import</span> <span class="n">StorageLevel</span>
<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">pyspark.ml._typing</span> <span class="kn">import</span> <span class="n">P</span><span class="p">,</span> <span class="n">ParamMap</span>
<span class="kn">from</span> <span class="nn">py4j.java_gateway</span> <span class="kn">import</span> <span class="n">JavaObject</span>
<span class="kn">from</span> <span class="nn">pyspark.core.context</span> <span class="kn">import</span> <span class="n">SparkContext</span>
<span class="n">T</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s2">&quot;T&quot;</span><span class="p">)</span>
<span class="n">JPM</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s2">&quot;JPM&quot;</span><span class="p">,</span> <span class="n">bound</span><span class="o">=</span><span class="n">JavaPredictionModel</span><span class="p">)</span>
<span class="n">CM</span> <span class="o">=</span> <span class="n">TypeVar</span><span class="p">(</span><span class="s2">&quot;CM&quot;</span><span class="p">,</span> <span class="n">bound</span><span class="o">=</span><span class="s2">&quot;ClassificationModel&quot;</span><span class="p">)</span>
<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span>
<span class="s2">&quot;LinearSVC&quot;</span><span class="p">,</span>
<span class="s2">&quot;LinearSVCModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;LinearSVCSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;LinearSVCTrainingSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;LogisticRegression&quot;</span><span class="p">,</span>
<span class="s2">&quot;LogisticRegressionModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;LogisticRegressionSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;LogisticRegressionTrainingSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;BinaryLogisticRegressionSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;BinaryLogisticRegressionTrainingSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">,</span>
<span class="s2">&quot;DecisionTreeClassificationModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;GBTClassifier&quot;</span><span class="p">,</span>
<span class="s2">&quot;GBTClassificationModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">,</span>
<span class="s2">&quot;RandomForestClassificationModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;RandomForestClassificationSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;RandomForestClassificationTrainingSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;BinaryRandomForestClassificationSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;BinaryRandomForestClassificationTrainingSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;NaiveBayes&quot;</span><span class="p">,</span>
<span class="s2">&quot;NaiveBayesModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">,</span>
<span class="s2">&quot;MultilayerPerceptronClassificationModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;MultilayerPerceptronClassificationSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;MultilayerPerceptronClassificationTrainingSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;OneVsRest&quot;</span><span class="p">,</span>
<span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;FMClassifier&quot;</span><span class="p">,</span>
<span class="s2">&quot;FMClassificationModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;FMClassificationSummary&quot;</span><span class="p">,</span>
<span class="s2">&quot;FMClassificationTrainingSummary&quot;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">class</span> <span class="nc">_ClassifierParams</span><span class="p">(</span><span class="n">HasRawPredictionCol</span><span class="p">,</span> <span class="n">_PredictorParams</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Classifier Params for classification tasks.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">Classifier</span><span class="p">(</span><span class="n">Predictor</span><span class="p">[</span><span class="n">CM</span><span class="p">],</span> <span class="n">_ClassifierParams</span><span class="p">,</span> <span class="n">Generic</span><span class="p">[</span><span class="n">CM</span><span class="p">],</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Classifier for classification tasks.</span>
<span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">&quot;P&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;P&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">ClassificationModel</span><span class="p">(</span><span class="n">PredictionModel</span><span class="p">,</span> <span class="n">_ClassifierParams</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model produced by a ``Classifier``.</span>
<span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">&quot;P&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;P&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@abstractmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">numClasses</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Number of classes (values which the label can take).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
<span class="nd">@abstractmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">predictRaw</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Raw prediction for each possible label.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
<span class="k">class</span> <span class="nc">_ProbabilisticClassifierParams</span><span class="p">(</span><span class="n">HasProbabilityCol</span><span class="p">,</span> <span class="n">HasThresholds</span><span class="p">,</span> <span class="n">_ClassifierParams</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`ProbabilisticClassifier` and</span>
<span class="sd"> :py:class:`ProbabilisticClassificationModel`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">ProbabilisticClassifier</span><span class="p">(</span><span class="n">Classifier</span><span class="p">,</span> <span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Probabilistic Classifier for classification tasks.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setProbabilityCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">&quot;P&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;P&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`probabilityCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">probabilityCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">&quot;P&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;P&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`thresholds`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">thresholds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">ProbabilisticClassificationModel</span><span class="p">(</span>
<span class="n">ClassificationModel</span><span class="p">,</span> <span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model produced by a ``ProbabilisticClassifier``.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setProbabilityCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">CM</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">CM</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`probabilityCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">probabilityCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="n">CM</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">CM</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`thresholds`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">thresholds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="nd">@abstractmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">predictProbability</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Predict the probability of each class given the features.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">raise</span> <span class="ne">NotImplementedError</span><span class="p">()</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_JavaClassifier</span><span class="p">(</span><span class="n">Classifier</span><span class="p">,</span> <span class="n">JavaPredictor</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">Generic</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Java Classifier for classification tasks.</span>
<span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">&quot;P&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;P&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_JavaClassificationModel</span><span class="p">(</span><span class="n">ClassificationModel</span><span class="p">,</span> <span class="n">JavaPredictionModel</span><span class="p">[</span><span class="n">T</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Java Model produced by a ``Classifier``.</span>
<span class="sd"> Classes are indexed {0, 1, ..., numClasses - 1}.</span>
<span class="sd"> To be mixed in with :class:`pyspark.ml.JavaModel`</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">numClasses</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Number of classes (values which the label can take).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;numClasses&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">predictRaw</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Raw prediction for each possible label.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;predictRaw&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_JavaProbabilisticClassifier</span><span class="p">(</span>
<span class="n">ProbabilisticClassifier</span><span class="p">,</span> <span class="n">_JavaClassifier</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">Generic</span><span class="p">[</span><span class="n">JPM</span><span class="p">],</span> <span class="n">metaclass</span><span class="o">=</span><span class="n">ABCMeta</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Java Probabilistic Classifier for classification tasks.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_JavaProbabilisticClassificationModel</span><span class="p">(</span>
<span class="n">ProbabilisticClassificationModel</span><span class="p">,</span> <span class="n">_JavaClassificationModel</span><span class="p">[</span><span class="n">T</span><span class="p">]</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Java Model produced by a ``ProbabilisticClassifier``.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">predictProbability</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Predict the probability of each class given the features.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;predictProbability&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_ClassificationSummary</span><span class="p">(</span><span class="n">JavaWrapper</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for multiclass classification results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">predictions</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> DataFrame output by the model&#39;s `transform` method.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;predictions&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">predictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Field in &quot;predictions&quot; which gives the prediction of each class.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;predictionCol&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">labelCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Field in &quot;predictions&quot; which gives the true label of each</span>
<span class="sd"> instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;labelCol&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">weightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Field in &quot;predictions&quot; which gives the weight of each instance</span>
<span class="sd"> as a vector.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;weightCol&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">labels</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns the sequence of labels in ascending order. This order matches the order used</span>
<span class="sd"> in metrics which are specified as arrays over labels, e.g., truePositiveRateByLabel.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> In most cases, it will be values {0.0, 1.0, ..., numClasses-1}. However, if the</span>
<span class="sd"> training set is missing a label, then all of the arrays over labels</span>
<span class="sd"> (e.g., from truePositiveRateByLabel) will be of length numClasses-1 instead of the</span>
<span class="sd"> expected numClasses.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;labels&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">truePositiveRateByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns true positive rate for each label (category).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;truePositiveRateByLabel&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">falsePositiveRateByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns false positive rate for each label (category).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;falsePositiveRateByLabel&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">precisionByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns precision for each label (category).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;precisionByLabel&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">recallByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns recall for each label (category).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;recallByLabel&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">fMeasureByLabel</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns f-measure for each label (category).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;fMeasureByLabel&quot;</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">accuracy</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns accuracy.</span>
<span class="sd"> (equal to the total number of correctly classified instances</span>
<span class="sd"> out of the total number of instances.)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;accuracy&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">weightedTruePositiveRate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns weighted true positive rate.</span>
<span class="sd"> (equal to precision, recall and f-measure)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;weightedTruePositiveRate&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">weightedFalsePositiveRate</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns weighted false positive rate.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;weightedFalsePositiveRate&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">weightedRecall</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns weighted averaged recall.</span>
<span class="sd"> (equal to precision, recall and f-measure)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;weightedRecall&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">weightedPrecision</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns weighted averaged precision.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;weightedPrecision&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">weightedFMeasure</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns weighted averaged f-measure.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;weightedFMeasure&quot;</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_TrainingSummary</span><span class="p">(</span><span class="n">JavaWrapper</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for Training results.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">objectiveHistory</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Objective function (scaled loss + regularization) at each</span>
<span class="sd"> iteration. It contains one more element, the initial state,</span>
<span class="sd"> than the number of iterations.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;objectiveHistory&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">totalIterations</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Number of training iterations until termination.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;totalIterations&quot;</span><span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_BinaryClassificationSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Binary classification results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">scoreCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Field in &quot;predictions&quot; which gives the probability or raw prediction</span>
<span class="sd"> of each class as a vector.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;scoreCol&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">roc</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns the receiver operating characteristic (ROC) curve,</span>
<span class="sd"> which is a DataFrame having two fields (FPR, TPR) with</span>
<span class="sd"> (0.0, 0.0) prepended and (1.0, 1.0) appended to it.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> `Wikipedia reference &lt;https://en.wikipedia.org/wiki/Receiver_operating_characteristic&gt;`_</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;roc&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">areaUnderROC</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Computes the area under the receiver operating characteristic</span>
<span class="sd"> (ROC) curve.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;areaUnderROC&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">pr</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns the precision-recall curve, which is a DataFrame</span>
<span class="sd"> containing two fields (recall, precision) with (0.0, 1.0) prepended</span>
<span class="sd"> to it.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;pr&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">fMeasureByThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns a dataframe with two fields (threshold, F-Measure) curve</span>
<span class="sd"> with beta = 1.0.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;fMeasureByThreshold&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">precisionByThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns a dataframe with two fields (threshold, precision) curve.</span>
<span class="sd"> Every possible probability obtained in transforming the dataset</span>
<span class="sd"> is used as a threshold in calculating the precision.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;precisionByThreshold&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">recallByThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Returns a dataframe with two fields (threshold, recall) curve.</span>
<span class="sd"> Every possible probability obtained in transforming the dataset</span>
<span class="sd"> is used as a threshold in calculating the recall.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;recallByThreshold&quot;</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">_LinearSVCParams</span><span class="p">(</span>
<span class="n">_ClassifierParams</span><span class="p">,</span>
<span class="n">HasRegParam</span><span class="p">,</span>
<span class="n">HasMaxIter</span><span class="p">,</span>
<span class="n">HasFitIntercept</span><span class="p">,</span>
<span class="n">HasTol</span><span class="p">,</span>
<span class="n">HasStandardization</span><span class="p">,</span>
<span class="n">HasWeightCol</span><span class="p">,</span>
<span class="n">HasAggregationDepth</span><span class="p">,</span>
<span class="n">HasThreshold</span><span class="p">,</span>
<span class="n">HasMaxBlockSizeInMB</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`LinearSVC` and :py:class:`LinearSVCModel`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">threshold</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;threshold&quot;</span><span class="p">,</span>
<span class="s2">&quot;The threshold in binary classification applied to the linear model&quot;</span>
<span class="s2">&quot; prediction. This threshold can be any real number, where Inf will make&quot;</span>
<span class="s2">&quot; all predictions 0.0 and -Inf will make all predictions 1.0.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_LinearSVCParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span>
<span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span>
<span class="n">regParam</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">tol</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">standardization</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="n">threshold</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span>
<div class="viewcode-block" id="LinearSVC"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">LinearSVC</span><span class="p">(</span>
<span class="n">_JavaClassifier</span><span class="p">[</span><span class="s2">&quot;LinearSVCModel&quot;</span><span class="p">],</span>
<span class="n">_LinearSVCParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;LinearSVC&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.</span>
<span class="sd"> Only supports L2 regularization currently.</span>
<span class="sd"> .. versionadded:: 2.2.0</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> `Linear SVM Classifier &lt;https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM&gt;`_</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.sql import Row</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; df = sc.parallelize([</span>
<span class="sd"> ... Row(label=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),</span>
<span class="sd"> ... Row(label=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; svm = LinearSVC()</span>
<span class="sd"> &gt;&gt;&gt; svm.getMaxIter()</span>
<span class="sd"> 100</span>
<span class="sd"> &gt;&gt;&gt; svm.setMaxIter(5)</span>
<span class="sd"> LinearSVC...</span>
<span class="sd"> &gt;&gt;&gt; svm.getMaxIter()</span>
<span class="sd"> 5</span>
<span class="sd"> &gt;&gt;&gt; svm.getRegParam()</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; svm.setRegParam(0.01)</span>
<span class="sd"> LinearSVC...</span>
<span class="sd"> &gt;&gt;&gt; svm.getRegParam()</span>
<span class="sd"> 0.01</span>
<span class="sd"> &gt;&gt;&gt; model = svm.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model.setPredictionCol(&quot;newPrediction&quot;)</span>
<span class="sd"> LinearSVCModel...</span>
<span class="sd"> &gt;&gt;&gt; model.getPredictionCol()</span>
<span class="sd"> &#39;newPrediction&#39;</span>
<span class="sd"> &gt;&gt;&gt; model.setThreshold(0.5)</span>
<span class="sd"> LinearSVCModel...</span>
<span class="sd"> &gt;&gt;&gt; model.getThreshold()</span>
<span class="sd"> 0.5</span>
<span class="sd"> &gt;&gt;&gt; model.getMaxBlockSizeInMB()</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; model.coefficients</span>
<span class="sd"> DenseVector([0.0, -1.0319, -0.5159])</span>
<span class="sd"> &gt;&gt;&gt; model.intercept</span>
<span class="sd"> 2.579645978780695</span>
<span class="sd"> &gt;&gt;&gt; model.numClasses</span>
<span class="sd"> 2</span>
<span class="sd"> &gt;&gt;&gt; model.numFeatures</span>
<span class="sd"> 3</span>
<span class="sd"> &gt;&gt;&gt; test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, -1.0, -1.0))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; model.predict(test0.head().features)</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predictRaw(test0.head().features)</span>
<span class="sd"> DenseVector([-4.1274, 4.1274])</span>
<span class="sd"> &gt;&gt;&gt; result = model.transform(test0).head()</span>
<span class="sd"> &gt;&gt;&gt; result.newPrediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; result.rawPrediction</span>
<span class="sd"> DenseVector([-4.1274, 4.1274])</span>
<span class="sd"> &gt;&gt;&gt; svm_path = temp_path + &quot;/svm&quot;</span>
<span class="sd"> &gt;&gt;&gt; svm.save(svm_path)</span>
<span class="sd"> &gt;&gt;&gt; svm2 = LinearSVC.load(svm_path)</span>
<span class="sd"> &gt;&gt;&gt; svm2.getMaxIter()</span>
<span class="sd"> 5</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/svm_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = LinearSVCModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model.coefficients[0] == model2.coefficients[0]</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.intercept == model2.intercept</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \</span>
<span class="sd"> aggregationDepth=2, maxBlockSizeInMB=0.0):</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">LinearSVC</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.LinearSVC&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="LinearSVC.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxIter=100, regParam=0.0, tol=1e-6, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> fitIntercept=True, standardization=True, threshold=0.0, weightCol=None, \</span>
<span class="sd"> aggregationDepth=2, maxBlockSizeInMB=0.0):</span>
<span class="sd"> Sets params for Linear SVM Classifier.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVCModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">LinearSVCModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<div class="viewcode-block" id="LinearSVC.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setMaxIter">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxIter`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setRegParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setRegParam">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setRegParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`regParam`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">regParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setTol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`tol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setFitIntercept"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setFitIntercept">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFitIntercept</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`fitIntercept`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitIntercept</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setStandardization"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setStandardization">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setStandardization</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`standardization`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">standardization</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setThreshold"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setThreshold">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`threshold`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`weightCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setAggregationDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setAggregationDepth">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setAggregationDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`aggregationDepth`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">aggregationDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVC.setMaxBlockSizeInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVC.html#pyspark.ml.classification.LinearSVC.setMaxBlockSizeInMB">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMaxBlockSizeInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVC&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxBlockSizeInMB`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="LinearSVCModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel">[docs]</a><span class="k">class</span> <span class="nc">LinearSVCModel</span><span class="p">(</span>
<span class="n">_JavaClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_LinearSVCParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;LinearSVCModel&quot;</span><span class="p">],</span>
<span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">&quot;LinearSVCTrainingSummary&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by LinearSVC.</span>
<span class="sd"> .. versionadded:: 2.2.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<div class="viewcode-block" id="LinearSVCModel.setThreshold"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel.setThreshold">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVCModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`threshold`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">coefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model coefficients of Linear SVM Classifier.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;coefficients&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">intercept</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model intercept of Linear SVM Classifier.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;intercept&quot;</span><span class="p">)</span>
<div class="viewcode-block" id="LinearSVCModel.summary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel.summary">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVCTrainingSummary&quot;</span><span class="p">:</span> <span class="c1"># type: ignore[override]</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span>
<span class="sd"> trained on the training set. An exception is thrown if `trainingSummary is None`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span>
<span class="k">return</span> <span class="n">LinearSVCTrainingSummary</span><span class="p">(</span><span class="nb">super</span><span class="p">(</span><span class="n">LinearSVCModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;No training summary available for this </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="LinearSVCModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCModel.html#pyspark.ml.classification.LinearSVCModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LinearSVCSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Evaluates the model on a test dataset.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> Test dataset to evaluate model on.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">.&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
<span class="n">java_lsvc_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;evaluate&quot;</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="k">return</span> <span class="n">LinearSVCSummary</span><span class="p">(</span><span class="n">java_lsvc_summary</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="LinearSVCSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCSummary.html#pyspark.ml.classification.LinearSVCSummary">[docs]</a><span class="k">class</span> <span class="nc">LinearSVCSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for LinearSVC Results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="LinearSVCTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LinearSVCTrainingSummary.html#pyspark.ml.classification.LinearSVCTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">LinearSVCTrainingSummary</span><span class="p">(</span><span class="n">LinearSVCSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for LinearSVC Training results.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<span class="k">class</span> <span class="nc">_LogisticRegressionParams</span><span class="p">(</span>
<span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span>
<span class="n">HasRegParam</span><span class="p">,</span>
<span class="n">HasElasticNetParam</span><span class="p">,</span>
<span class="n">HasMaxIter</span><span class="p">,</span>
<span class="n">HasFitIntercept</span><span class="p">,</span>
<span class="n">HasTol</span><span class="p">,</span>
<span class="n">HasStandardization</span><span class="p">,</span>
<span class="n">HasWeightCol</span><span class="p">,</span>
<span class="n">HasAggregationDepth</span><span class="p">,</span>
<span class="n">HasThreshold</span><span class="p">,</span>
<span class="n">HasMaxBlockSizeInMB</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`LogisticRegression` and :py:class:`LogisticRegressionModel`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">threshold</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;threshold&quot;</span><span class="p">,</span>
<span class="s2">&quot;Threshold in binary classification prediction, in range [0, 1].&quot;</span>
<span class="o">+</span> <span class="s2">&quot; If threshold and thresholds are both set, they must match.&quot;</span>
<span class="o">+</span> <span class="s2">&quot; e.g. if threshold is p, then thresholds must be equal to [1-p, p].&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">family</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;family&quot;</span><span class="p">,</span>
<span class="s2">&quot;The name of family which is a description of the label distribution to &quot;</span>
<span class="o">+</span> <span class="s2">&quot;be used in the model. Supported options: auto, binomial, multinomial&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;lowerBoundsOnCoefficients&quot;</span><span class="p">,</span>
<span class="s2">&quot;The lower bounds on coefficients if fitting under bound &quot;</span>
<span class="s2">&quot;constrained optimization. The bound matrix must be &quot;</span>
<span class="s2">&quot;compatible with the shape &quot;</span>
<span class="s2">&quot;(1, number of features) for binomial regression, or &quot;</span>
<span class="s2">&quot;(number of classes, number of features) &quot;</span>
<span class="s2">&quot;for multinomial regression.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toMatrix</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;upperBoundsOnCoefficients&quot;</span><span class="p">,</span>
<span class="s2">&quot;The upper bounds on coefficients if fitting under bound &quot;</span>
<span class="s2">&quot;constrained optimization. The bound matrix must be &quot;</span>
<span class="s2">&quot;compatible with the shape &quot;</span>
<span class="s2">&quot;(1, number of features) for binomial regression, or &quot;</span>
<span class="s2">&quot;(number of classes, number of features) &quot;</span>
<span class="s2">&quot;for multinomial regression.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toMatrix</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;lowerBoundsOnIntercepts&quot;</span><span class="p">,</span>
<span class="s2">&quot;The lower bounds on intercepts if fitting under bound &quot;</span>
<span class="s2">&quot;constrained optimization. The bounds vector size must be &quot;</span>
<span class="s2">&quot;equal with 1 for binomial regression, or the number of &quot;</span>
<span class="s2">&quot;classes for multinomial regression.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toVector</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;upperBoundsOnIntercepts&quot;</span><span class="p">,</span>
<span class="s2">&quot;The upper bounds on intercepts if fitting under bound &quot;</span>
<span class="s2">&quot;constrained optimization. The bound vector size must be &quot;</span>
<span class="s2">&quot;equal with 1 for binomial regression, or the number of &quot;</span>
<span class="s2">&quot;classes for multinomial regression.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toVector</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_LogisticRegressionParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span>
<span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">family</span><span class="o">=</span><span class="s2">&quot;auto&quot;</span><span class="p">,</span> <span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="mf">0.0</span>
<span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">&quot;P&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;P&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`threshold`.</span>
<span class="sd"> Clears value of :py:attr:`thresholds` if it has been set.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">threshold</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="k">return</span> <span class="bp">self</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getThreshold</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get threshold for binary classification.</span>
<span class="sd"> If :py:attr:`thresholds` is set with length 2 (i.e., binary classification),</span>
<span class="sd"> this returns the equivalent threshold:</span>
<span class="sd"> :math:`\\frac{1}{1 + \\frac{thresholds(0)}{thresholds(1)}}`.</span>
<span class="sd"> Otherwise, returns :py:attr:`threshold` if set or its default value if unset.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">):</span>
<span class="n">ts</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ts</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Logistic Regression getThreshold only applies to&quot;</span>
<span class="o">+</span> <span class="s2">&quot; binary classification, but thresholds has length != 2.&quot;</span>
<span class="o">+</span> <span class="s2">&quot; thresholds: </span><span class="si">{ts}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">ts</span><span class="o">=</span><span class="n">ts</span><span class="p">)</span>
<span class="p">)</span>
<span class="k">return</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">ts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">ts</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.5.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">:</span> <span class="s2">&quot;P&quot;</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;P&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`thresholds`.</span>
<span class="sd"> Clears value of :py:attr:`threshold` if it has been set.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">thresholds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">clear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span> <span class="c1"># type: ignore[attr-defined]</span>
<span class="k">return</span> <span class="bp">self</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.5.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getThresholds</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> If :py:attr:`thresholds` is set, return its value.</span>
<span class="sd"> Otherwise, if :py:attr:`threshold` is set, return the equivalent thresholds for binary</span>
<span class="sd"> classification: (1-threshold, threshold).</span>
<span class="sd"> If neither are set, throw an error.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span>
<span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">):</span>
<span class="n">t</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span>
<span class="k">return</span> <span class="p">[</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">t</span><span class="p">,</span> <span class="n">t</span><span class="p">]</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">_checkThresholdConsistency</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">):</span>
<span class="n">ts</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">thresholds</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ts</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Logistic Regression getThreshold only applies to&quot;</span>
<span class="o">+</span> <span class="s2">&quot; binary classification, but thresholds has length != 2.&quot;</span>
<span class="o">+</span> <span class="s2">&quot; thresholds: </span><span class="si">{0}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">ts</span><span class="p">))</span>
<span class="p">)</span>
<span class="n">t</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">+</span> <span class="n">ts</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="n">ts</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
<span class="n">t2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">)</span>
<span class="k">if</span> <span class="nb">abs</span><span class="p">(</span><span class="n">t2</span> <span class="o">-</span> <span class="n">t</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mf">1e-5</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="s2">&quot;Logistic Regression getThreshold found inconsistent values for&quot;</span>
<span class="o">+</span> <span class="s2">&quot; threshold (</span><span class="si">%g</span><span class="s2">) and thresholds (equivalent to </span><span class="si">%g</span><span class="s2">)&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">t2</span><span class="p">,</span> <span class="n">t</span><span class="p">)</span>
<span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getFamily</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of :py:attr:`family` or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">family</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getLowerBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Matrix</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of :py:attr:`lowerBoundsOnCoefficients`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lowerBoundsOnCoefficients</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getUpperBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Matrix</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of :py:attr:`upperBoundsOnCoefficients`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">upperBoundsOnCoefficients</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getLowerBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of :py:attr:`lowerBoundsOnIntercepts`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lowerBoundsOnIntercepts</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getUpperBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of :py:attr:`upperBoundsOnIntercepts`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">upperBoundsOnIntercepts</span><span class="p">)</span>
<div class="viewcode-block" id="LogisticRegression"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">LogisticRegression</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">&quot;LogisticRegressionModel&quot;</span><span class="p">],</span>
<span class="n">_LogisticRegressionParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;LogisticRegression&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Logistic regression.</span>
<span class="sd"> This class supports multinomial logistic (softmax) and binomial logistic regression.</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.sql import Row</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; bdf = sc.parallelize([</span>
<span class="sd"> ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)),</span>
<span class="sd"> ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)),</span>
<span class="sd"> ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)),</span>
<span class="sd"> ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; blor = LogisticRegression(weightCol=&quot;weight&quot;)</span>
<span class="sd"> &gt;&gt;&gt; blor.getRegParam()</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; blor.setRegParam(0.01)</span>
<span class="sd"> LogisticRegression...</span>
<span class="sd"> &gt;&gt;&gt; blor.getRegParam()</span>
<span class="sd"> 0.01</span>
<span class="sd"> &gt;&gt;&gt; blor.setMaxIter(10)</span>
<span class="sd"> LogisticRegression...</span>
<span class="sd"> &gt;&gt;&gt; blor.getMaxIter()</span>
<span class="sd"> 10</span>
<span class="sd"> &gt;&gt;&gt; blor.clear(blor.maxIter)</span>
<span class="sd"> &gt;&gt;&gt; blorModel = blor.fit(bdf)</span>
<span class="sd"> &gt;&gt;&gt; blorModel.setFeaturesCol(&quot;features&quot;)</span>
<span class="sd"> LogisticRegressionModel...</span>
<span class="sd"> &gt;&gt;&gt; blorModel.setProbabilityCol(&quot;newProbability&quot;)</span>
<span class="sd"> LogisticRegressionModel...</span>
<span class="sd"> &gt;&gt;&gt; blorModel.getProbabilityCol()</span>
<span class="sd"> &#39;newProbability&#39;</span>
<span class="sd"> &gt;&gt;&gt; blorModel.getMaxBlockSizeInMB()</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; blorModel.setThreshold(0.1)</span>
<span class="sd"> LogisticRegressionModel...</span>
<span class="sd"> &gt;&gt;&gt; blorModel.getThreshold()</span>
<span class="sd"> 0.1</span>
<span class="sd"> &gt;&gt;&gt; blorModel.coefficients</span>
<span class="sd"> DenseVector([-1.080..., -0.646...])</span>
<span class="sd"> &gt;&gt;&gt; blorModel.intercept</span>
<span class="sd"> 3.112...</span>
<span class="sd"> &gt;&gt;&gt; blorModel.evaluate(bdf).accuracy == blorModel.summary.accuracy</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; data_path = &quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span>
<span class="sd"> &gt;&gt;&gt; mdf = spark.read.format(&quot;libsvm&quot;).load(data_path)</span>
<span class="sd"> &gt;&gt;&gt; mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family=&quot;multinomial&quot;)</span>
<span class="sd"> &gt;&gt;&gt; mlorModel = mlor.fit(mdf)</span>
<span class="sd"> &gt;&gt;&gt; mlorModel.coefficientMatrix</span>
<span class="sd"> SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1)</span>
<span class="sd"> &gt;&gt;&gt; mlorModel.interceptVector</span>
<span class="sd"> DenseVector([0.04..., -0.42..., 0.37...])</span>
<span class="sd"> &gt;&gt;&gt; test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; blorModel.predict(test0.head().features)</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; blorModel.predictRaw(test0.head().features)</span>
<span class="sd"> DenseVector([-3.54..., 3.54...])</span>
<span class="sd"> &gt;&gt;&gt; blorModel.predictProbability(test0.head().features)</span>
<span class="sd"> DenseVector([0.028, 0.972])</span>
<span class="sd"> &gt;&gt;&gt; result = blorModel.transform(test0).head()</span>
<span class="sd"> &gt;&gt;&gt; result.prediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; result.newProbability</span>
<span class="sd"> DenseVector([0.02..., 0.97...])</span>
<span class="sd"> &gt;&gt;&gt; result.rawPrediction</span>
<span class="sd"> DenseVector([-3.54..., 3.54...])</span>
<span class="sd"> &gt;&gt;&gt; test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; blorModel.transform(test1).head().prediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; blor.setParams(&quot;vector&quot;)</span>
<span class="sd"> Traceback (most recent call last):</span>
<span class="sd"> ...</span>
<span class="sd"> TypeError: Method setParams forces keyword arguments.</span>
<span class="sd"> &gt;&gt;&gt; lr_path = temp_path + &quot;/lr&quot;</span>
<span class="sd"> &gt;&gt;&gt; blor.save(lr_path)</span>
<span class="sd"> &gt;&gt;&gt; lr2 = LogisticRegression.load(lr_path)</span>
<span class="sd"> &gt;&gt;&gt; lr2.getRegParam()</span>
<span class="sd"> 0.01</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/lr_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; blorModel.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = LogisticRegressionModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; blorModel.coefficients[0] == model2.coefficients[0]</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; blorModel.intercept == model2.intercept</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model2</span>
<span class="sd"> LogisticRegressionModel: uid=..., numClasses=2, numFeatures=2</span>
<span class="sd"> &gt;&gt;&gt; blorModel.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="p">):</span>
<span class="o">...</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="p">):</span>
<span class="o">...</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
<span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \</span>
<span class="sd"> threshold=0.5, thresholds=None, probabilityCol=&quot;probability&quot;, \</span>
<span class="sd"> rawPredictionCol=&quot;rawPrediction&quot;, standardization=True, weightCol=None, \</span>
<span class="sd"> aggregationDepth=2, family=&quot;auto&quot;, \</span>
<span class="sd"> lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \</span>
<span class="sd"> lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \</span>
<span class="sd"> maxBlockSizeInMB=0.0):</span>
<span class="sd"> If the threshold and thresholds Params are both set, they must be equivalent.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">LogisticRegression</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.LogisticRegression&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="o">...</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="o">...</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="o">...</span>
<div class="viewcode-block" id="LogisticRegression.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">elasticNetParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">threshold</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.5</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">standardization</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">aggregationDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span><span class="p">,</span>
<span class="n">family</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
<span class="n">lowerBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">upperBoundsOnCoefficients</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Matrix</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">lowerBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">upperBoundsOnIntercepts</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">maxBlockSizeInMB</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxIter=100, regParam=0.0, elasticNetParam=0.0, tol=1e-6, fitIntercept=True, \</span>
<span class="sd"> threshold=0.5, thresholds=None, probabilityCol=&quot;probability&quot;, \</span>
<span class="sd"> rawPredictionCol=&quot;rawPrediction&quot;, standardization=True, weightCol=None, \</span>
<span class="sd"> aggregationDepth=2, family=&quot;auto&quot;, \</span>
<span class="sd"> lowerBoundsOnCoefficients=None, upperBoundsOnCoefficients=None, \</span>
<span class="sd"> lowerBoundsOnIntercepts=None, upperBoundsOnIntercepts=None, \</span>
<span class="sd"> maxBlockSizeInMB=0.0):</span>
<span class="sd"> Sets params for logistic regression.</span>
<span class="sd"> If the threshold and thresholds Params are both set, they must be equivalent.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_checkThresholdConsistency</span><span class="p">()</span>
<span class="k">return</span> <span class="bp">self</span></div>
<span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegressionModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">LogisticRegressionModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<div class="viewcode-block" id="LogisticRegression.setFamily"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setFamily">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFamily</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`family`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">family</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setLowerBoundsOnCoefficients"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setLowerBoundsOnCoefficients">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setLowerBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Matrix</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`lowerBoundsOnCoefficients`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">lowerBoundsOnCoefficients</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setUpperBoundsOnCoefficients"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setUpperBoundsOnCoefficients">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setUpperBoundsOnCoefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Matrix</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`upperBoundsOnCoefficients`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">upperBoundsOnCoefficients</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setLowerBoundsOnIntercepts"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setLowerBoundsOnIntercepts">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setLowerBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`lowerBoundsOnIntercepts`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">lowerBoundsOnIntercepts</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setUpperBoundsOnIntercepts"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setUpperBoundsOnIntercepts">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setUpperBoundsOnIntercepts</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`upperBoundsOnIntercepts`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">upperBoundsOnIntercepts</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setMaxIter">[docs]</a> <span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxIter`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setRegParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setRegParam">[docs]</a> <span class="k">def</span> <span class="nf">setRegParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`regParam`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">regParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setTol">[docs]</a> <span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`tol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setElasticNetParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setElasticNetParam">[docs]</a> <span class="k">def</span> <span class="nf">setElasticNetParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`elasticNetParam`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">elasticNetParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setFitIntercept"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setFitIntercept">[docs]</a> <span class="k">def</span> <span class="nf">setFitIntercept</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`fitIntercept`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitIntercept</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setStandardization"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setStandardization">[docs]</a> <span class="k">def</span> <span class="nf">setStandardization</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`standardization`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">standardization</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setWeightCol">[docs]</a> <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`weightCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setAggregationDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setAggregationDepth">[docs]</a> <span class="k">def</span> <span class="nf">setAggregationDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`aggregationDepth`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">aggregationDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegression.setMaxBlockSizeInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegression.html#pyspark.ml.classification.LogisticRegression.setMaxBlockSizeInMB">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMaxBlockSizeInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegression&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxBlockSizeInMB`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBlockSizeInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="LogisticRegressionModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionModel.html#pyspark.ml.classification.LogisticRegressionModel">[docs]</a><span class="k">class</span> <span class="nc">LogisticRegressionModel</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_LogisticRegressionParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;LogisticRegressionModel&quot;</span><span class="p">],</span>
<span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">&quot;LogisticRegressionTrainingSummary&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by LogisticRegression.</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">coefficients</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model coefficients of binomial logistic regression.</span>
<span class="sd"> An exception is thrown in the case of multinomial logistic regression.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;coefficients&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">intercept</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model intercept of binomial logistic regression.</span>
<span class="sd"> An exception is thrown in the case of multinomial logistic regression.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;intercept&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">coefficientMatrix</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Matrix</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model coefficients.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;coefficientMatrix&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">interceptVector</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model intercept.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;interceptVector&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegressionTrainingSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span>
<span class="sd"> trained on the training set. A RuntimeError is raised if no training summary is available.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">return</span> <span class="n">BinaryLogisticRegressionTrainingSummary</span><span class="p">(</span>
<span class="nb">super</span><span class="p">(</span><span class="n">LogisticRegressionModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">LogisticRegressionTrainingSummary</span><span class="p">(</span>
<span class="nb">super</span><span class="p">(</span><span class="n">LogisticRegressionModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;No training summary available for this </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
<span class="p">)</span>
<div class="viewcode-block" id="LogisticRegressionModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionModel.html#pyspark.ml.classification.LogisticRegressionModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;LogisticRegressionSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Evaluates the model on a test dataset.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> Test dataset to evaluate model on.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">.&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
<span class="n">java_blr_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;evaluate&quot;</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">return</span> <span class="n">BinaryLogisticRegressionSummary</span><span class="p">(</span><span class="n">java_blr_summary</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">LogisticRegressionSummary</span><span class="p">(</span><span class="n">java_blr_summary</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="LogisticRegressionSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionSummary.html#pyspark.ml.classification.LogisticRegressionSummary">[docs]</a><span class="k">class</span> <span class="nc">LogisticRegressionSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for Logistic Regression Results for a given model.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">probabilityCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Field in &quot;predictions&quot; which gives the probability</span>
<span class="sd"> of each class as a vector.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;probabilityCol&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">featuresCol</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Field in &quot;predictions&quot; which gives the features of each instance</span>
<span class="sd"> as a vector.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;featuresCol&quot;</span><span class="p">)</span></div>
<div class="viewcode-block" id="LogisticRegressionTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.LogisticRegressionTrainingSummary.html#pyspark.ml.classification.LogisticRegressionTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">LogisticRegressionTrainingSummary</span><span class="p">(</span><span class="n">LogisticRegressionSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for multinomial Logistic Regression Training results.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="BinaryLogisticRegressionSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryLogisticRegressionSummary.html#pyspark.ml.classification.BinaryLogisticRegressionSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">BinaryLogisticRegressionSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">,</span> <span class="n">LogisticRegressionSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Binary logistic regression results for a given model.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="BinaryLogisticRegressionTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary.html#pyspark.ml.classification.BinaryLogisticRegressionTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">BinaryLogisticRegressionTrainingSummary</span><span class="p">(</span>
<span class="n">BinaryLogisticRegressionSummary</span><span class="p">,</span> <span class="n">LogisticRegressionTrainingSummary</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Binary logistic regression training results for a given model.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_DecisionTreeClassifierParams</span><span class="p">(</span><span class="n">_DecisionTreeParams</span><span class="p">,</span> <span class="n">_TreeClassifierParams</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`DecisionTreeClassifier` and :py:class:`DecisionTreeClassificationModel`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_DecisionTreeClassifierParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span>
<span class="n">maxDepth</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="n">impurity</span><span class="o">=</span><span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">leafCol</span><span class="o">=</span><span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span>
<div class="viewcode-block" id="DecisionTreeClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">DecisionTreeClassifier</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">&quot;DecisionTreeClassificationModel&quot;</span><span class="p">],</span>
<span class="n">_DecisionTreeClassifierParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> `Decision tree &lt;https://en.wikipedia.org/wiki/Decision_tree_learning&gt;`_</span>
<span class="sd"> learning algorithm for classification.</span>
<span class="sd"> It supports both binary and multiclass labels, as well as both continuous and categorical</span>
<span class="sd"> features.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.feature import StringIndexer</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame([</span>
<span class="sd"> ... (1.0, Vectors.dense(1.0)),</span>
<span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], [&quot;label&quot;, &quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; stringIndexer = StringIndexer(inputCol=&quot;label&quot;, outputCol=&quot;indexed&quot;)</span>
<span class="sd"> &gt;&gt;&gt; si_model = stringIndexer.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; td = si_model.transform(df)</span>
<span class="sd"> &gt;&gt;&gt; dt = DecisionTreeClassifier(maxDepth=2, labelCol=&quot;indexed&quot;, leafCol=&quot;leafId&quot;)</span>
<span class="sd"> &gt;&gt;&gt; model = dt.fit(td)</span>
<span class="sd"> &gt;&gt;&gt; model.getLabelCol()</span>
<span class="sd"> &#39;indexed&#39;</span>
<span class="sd"> &gt;&gt;&gt; model.setFeaturesCol(&quot;features&quot;)</span>
<span class="sd"> DecisionTreeClassificationModel...</span>
<span class="sd"> &gt;&gt;&gt; model.numNodes</span>
<span class="sd"> 3</span>
<span class="sd"> &gt;&gt;&gt; model.depth</span>
<span class="sd"> 1</span>
<span class="sd"> &gt;&gt;&gt; model.featureImportances</span>
<span class="sd"> SparseVector(1, {0: 1.0})</span>
<span class="sd"> &gt;&gt;&gt; model.numFeatures</span>
<span class="sd"> 1</span>
<span class="sd"> &gt;&gt;&gt; model.numClasses</span>
<span class="sd"> 2</span>
<span class="sd"> &gt;&gt;&gt; print(model.toDebugString)</span>
<span class="sd"> DecisionTreeClassificationModel...depth=1, numNodes=3...</span>
<span class="sd"> &gt;&gt;&gt; test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(test0.head().features)</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; model.predictRaw(test0.head().features)</span>
<span class="sd"> DenseVector([1.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; model.predictProbability(test0.head().features)</span>
<span class="sd"> DenseVector([1.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; result = model.transform(test0).head()</span>
<span class="sd"> &gt;&gt;&gt; result.prediction</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; result.probability</span>
<span class="sd"> DenseVector([1.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; result.rawPrediction</span>
<span class="sd"> DenseVector([1.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; result.leafId</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test1).head().prediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; dtc_path = temp_path + &quot;/dtc&quot;</span>
<span class="sd"> &gt;&gt;&gt; dt.save(dtc_path)</span>
<span class="sd"> &gt;&gt;&gt; dt2 = DecisionTreeClassifier.load(dtc_path)</span>
<span class="sd"> &gt;&gt;&gt; dt2.getMaxDepth()</span>
<span class="sd"> 2</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/dtc_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = DecisionTreeClassificationModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model.featureImportances == model2.featureImportances</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; df3 = spark.createDataFrame([</span>
<span class="sd"> ... (1.0, 0.2, Vectors.dense(1.0)),</span>
<span class="sd"> ... (1.0, 0.8, Vectors.dense(1.0)),</span>
<span class="sd"> ... (0.0, 1.0, Vectors.sparse(1, [], []))], [&quot;label&quot;, &quot;weight&quot;, &quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; si3 = StringIndexer(inputCol=&quot;label&quot;, outputCol=&quot;indexed&quot;)</span>
<span class="sd"> &gt;&gt;&gt; si_model3 = si3.fit(df3)</span>
<span class="sd"> &gt;&gt;&gt; td3 = si_model3.transform(df3)</span>
<span class="sd"> &gt;&gt;&gt; dt3 = DecisionTreeClassifier(maxDepth=2, weightCol=&quot;weight&quot;, labelCol=&quot;indexed&quot;)</span>
<span class="sd"> &gt;&gt;&gt; model3 = dt3.fit(td3)</span>
<span class="sd"> &gt;&gt;&gt; print(model3.toDebugString)</span>
<span class="sd"> DecisionTreeClassificationModel...depth=1, numNodes=3...</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span>
<span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity=&quot;gini&quot;, \</span>
<span class="sd"> seed=None, weightCol=None, leafCol=&quot;&quot;, minWeightFractionPerNode=0.0)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">DecisionTreeClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.DecisionTreeClassifier&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="DecisionTreeClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span>
<span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity=&quot;gini&quot;, \</span>
<span class="sd"> seed=None, weightCol=None, leafCol=&quot;&quot;, minWeightFractionPerNode=0.0)</span>
<span class="sd"> Sets params for the DecisionTreeClassifier.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassificationModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">DecisionTreeClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<div class="viewcode-block" id="DecisionTreeClassifier.setMaxDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMaxDepth">[docs]</a> <span class="k">def</span> <span class="nf">setMaxDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxDepth`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setMaxBins"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMaxBins">[docs]</a> <span class="k">def</span> <span class="nf">setMaxBins</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxBins`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBins</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setMinInstancesPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMinInstancesPerNode">[docs]</a> <span class="k">def</span> <span class="nf">setMinInstancesPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minInstancesPerNode`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInstancesPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setMinWeightFractionPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMinWeightFractionPerNode">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMinWeightFractionPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minWeightFractionPerNode`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setMinInfoGain"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMinInfoGain">[docs]</a> <span class="k">def</span> <span class="nf">setMinInfoGain</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minInfoGain`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInfoGain</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setMaxMemoryInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setMaxMemoryInMB">[docs]</a> <span class="k">def</span> <span class="nf">setMaxMemoryInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxMemoryInMB`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxMemoryInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setCacheNodeIds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setCacheNodeIds">[docs]</a> <span class="k">def</span> <span class="nf">setCacheNodeIds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`cacheNodeIds`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">cacheNodeIds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setImpurity"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setImpurity">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setImpurity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`impurity`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">impurity</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setCheckpointInterval"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setCheckpointInterval">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setCheckpointInterval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`checkpointInterval`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">checkpointInterval</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`seed`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTreeClassifier.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassifier.html#pyspark.ml.classification.DecisionTreeClassifier.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;DecisionTreeClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`weightCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="DecisionTreeClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.DecisionTreeClassificationModel.html#pyspark.ml.classification.DecisionTreeClassificationModel">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">DecisionTreeClassificationModel</span><span class="p">(</span>
<span class="n">_DecisionTreeModel</span><span class="p">,</span>
<span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_DecisionTreeClassifierParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;DecisionTreeClassificationModel&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by DecisionTreeClassifier.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">featureImportances</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Estimate of the importance of each feature.</span>
<span class="sd"> This generalizes the idea of &quot;Gini&quot; importance to other losses,</span>
<span class="sd"> following the explanation of Gini importance from &quot;Random Forests&quot; documentation</span>
<span class="sd"> by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.</span>
<span class="sd"> This feature importance is calculated as follows:</span>
<span class="sd"> - importance(feature j) = sum (over nodes which split on feature j) of the gain,</span>
<span class="sd"> where gain is scaled by the number of instances passing through node</span>
<span class="sd"> - Normalize importances for tree to sum to 1.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> Feature importance for single decision trees can have high variance due to</span>
<span class="sd"> correlated predictor variables. Consider using a :py:class:`RandomForestClassifier`</span>
<span class="sd"> to determine feature importance instead.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;featureImportances&quot;</span><span class="p">)</span></div>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">_RandomForestClassifierParams</span><span class="p">(</span><span class="n">_RandomForestParams</span><span class="p">,</span> <span class="n">_TreeClassifierParams</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`RandomForestClassifier` and :py:class:`RandomForestClassificationModel`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_RandomForestClassifierParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span>
<span class="n">maxDepth</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="n">impurity</span><span class="o">=</span><span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">numTrees</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="s2">&quot;auto&quot;</span><span class="p">,</span>
<span class="n">subsamplingRate</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span>
<span class="n">leafCol</span><span class="o">=</span><span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">bootstrap</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>
<div class="viewcode-block" id="RandomForestClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">RandomForestClassifier</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">&quot;RandomForestClassificationModel&quot;</span><span class="p">],</span>
<span class="n">_RandomForestClassifierParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> `Random Forest &lt;https://en.wikipedia.org/wiki/Random_forest&gt;`_</span>
<span class="sd"> learning algorithm for classification.</span>
<span class="sd"> It supports both binary and multiclass labels, as well as both continuous and categorical</span>
<span class="sd"> features.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; import numpy</span>
<span class="sd"> &gt;&gt;&gt; from numpy import allclose</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.feature import StringIndexer</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame([</span>
<span class="sd"> ... (1.0, Vectors.dense(1.0)),</span>
<span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], [&quot;label&quot;, &quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; stringIndexer = StringIndexer(inputCol=&quot;label&quot;, outputCol=&quot;indexed&quot;)</span>
<span class="sd"> &gt;&gt;&gt; si_model = stringIndexer.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; td = si_model.transform(df)</span>
<span class="sd"> &gt;&gt;&gt; rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol=&quot;indexed&quot;, seed=42,</span>
<span class="sd"> ... leafCol=&quot;leafId&quot;)</span>
<span class="sd"> &gt;&gt;&gt; rf.getMinWeightFractionPerNode()</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; model = rf.fit(td)</span>
<span class="sd"> &gt;&gt;&gt; model.getLabelCol()</span>
<span class="sd"> &#39;indexed&#39;</span>
<span class="sd"> &gt;&gt;&gt; model.setFeaturesCol(&quot;features&quot;)</span>
<span class="sd"> RandomForestClassificationModel...</span>
<span class="sd"> &gt;&gt;&gt; model.setRawPredictionCol(&quot;newRawPrediction&quot;)</span>
<span class="sd"> RandomForestClassificationModel...</span>
<span class="sd"> &gt;&gt;&gt; model.getBootstrap()</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.getRawPredictionCol()</span>
<span class="sd"> &#39;newRawPrediction&#39;</span>
<span class="sd"> &gt;&gt;&gt; model.featureImportances</span>
<span class="sd"> SparseVector(1, {0: 1.0})</span>
<span class="sd"> &gt;&gt;&gt; allclose(model.treeWeights, [1.0, 1.0, 1.0])</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(test0.head().features)</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; model.predictRaw(test0.head().features)</span>
<span class="sd"> DenseVector([2.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; model.predictProbability(test0.head().features)</span>
<span class="sd"> DenseVector([1.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; result = model.transform(test0).head()</span>
<span class="sd"> &gt;&gt;&gt; result.prediction</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; numpy.argmax(result.probability)</span>
<span class="sd"> 0</span>
<span class="sd"> &gt;&gt;&gt; numpy.argmax(result.newRawPrediction)</span>
<span class="sd"> 0</span>
<span class="sd"> &gt;&gt;&gt; result.leafId</span>
<span class="sd"> DenseVector([0.0, 0.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test1).head().prediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.trees</span>
<span class="sd"> [DecisionTreeClassificationModel...depth=..., DecisionTreeClassificationModel...]</span>
<span class="sd"> &gt;&gt;&gt; rfc_path = temp_path + &quot;/rfc&quot;</span>
<span class="sd"> &gt;&gt;&gt; rf.save(rfc_path)</span>
<span class="sd"> &gt;&gt;&gt; rf2 = RandomForestClassifier.load(rfc_path)</span>
<span class="sd"> &gt;&gt;&gt; rf2.getNumTrees()</span>
<span class="sd"> 3</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/rfc_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = RandomForestClassificationModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model.featureImportances == model2.featureImportances</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">numTrees</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bootstrap</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span>
<span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity=&quot;gini&quot;, \</span>
<span class="sd"> numTrees=20, featureSubsetStrategy=&quot;auto&quot;, seed=None, subsamplingRate=1.0, \</span>
<span class="sd"> leafCol=&quot;&quot;, minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">RandomForestClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.RandomForestClassifier&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="RandomForestClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">numTrees</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">bootstrap</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">bool</span><span class="p">]</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span>
<span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \</span>
<span class="sd"> impurity=&quot;gini&quot;, numTrees=20, featureSubsetStrategy=&quot;auto&quot;, subsamplingRate=1.0, \</span>
<span class="sd"> leafCol=&quot;&quot;, minWeightFractionPerNode=0.0, weightCol=None, bootstrap=True)</span>
<span class="sd"> Sets params for random forest classification.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassificationModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">RandomForestClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<div class="viewcode-block" id="RandomForestClassifier.setMaxDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMaxDepth">[docs]</a> <span class="k">def</span> <span class="nf">setMaxDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxDepth`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setMaxBins"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMaxBins">[docs]</a> <span class="k">def</span> <span class="nf">setMaxBins</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxBins`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBins</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setMinInstancesPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMinInstancesPerNode">[docs]</a> <span class="k">def</span> <span class="nf">setMinInstancesPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minInstancesPerNode`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInstancesPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setMinInfoGain"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMinInfoGain">[docs]</a> <span class="k">def</span> <span class="nf">setMinInfoGain</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minInfoGain`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInfoGain</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setMaxMemoryInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMaxMemoryInMB">[docs]</a> <span class="k">def</span> <span class="nf">setMaxMemoryInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxMemoryInMB`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxMemoryInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setCacheNodeIds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setCacheNodeIds">[docs]</a> <span class="k">def</span> <span class="nf">setCacheNodeIds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`cacheNodeIds`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">cacheNodeIds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setImpurity"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setImpurity">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setImpurity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`impurity`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">impurity</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setNumTrees"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setNumTrees">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setNumTrees</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`numTrees`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">numTrees</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setBootstrap"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setBootstrap">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setBootstrap</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`bootstrap`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">bootstrap</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setSubsamplingRate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setSubsamplingRate">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setSubsamplingRate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`subsamplingRate`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">subsamplingRate</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setFeatureSubsetStrategy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setFeatureSubsetStrategy">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFeatureSubsetStrategy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`featureSubsetStrategy`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`seed`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setCheckpointInterval"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setCheckpointInterval">[docs]</a> <span class="k">def</span> <span class="nf">setCheckpointInterval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`checkpointInterval`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">checkpointInterval</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`weightCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="RandomForestClassifier.setMinWeightFractionPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassifier.html#pyspark.ml.classification.RandomForestClassifier.setMinWeightFractionPerNode">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMinWeightFractionPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minWeightFractionPerNode`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="RandomForestClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationModel.html#pyspark.ml.classification.RandomForestClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">RandomForestClassificationModel</span><span class="p">(</span>
<span class="n">_TreeEnsembleModel</span><span class="p">,</span>
<span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_RandomForestClassifierParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;RandomForestClassificationModel&quot;</span><span class="p">],</span>
<span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">&quot;RandomForestClassificationTrainingSummary&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by RandomForestClassifier.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">featureImportances</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Estimate of the importance of each feature.</span>
<span class="sd"> Each feature&#39;s importance is the average of its importance across all trees in the ensemble.</span>
<span class="sd"> The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.</span>
<span class="sd"> (Hastie, Tibshirani, Friedman. &quot;The Elements of Statistical Learning, 2nd Edition.&quot; 2001.)</span>
<span class="sd"> and follows the implementation from scikit-learn.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> See Also</span>
<span class="sd"> --------</span>
<span class="sd"> DecisionTreeClassificationModel.featureImportances</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;featureImportances&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">trees</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">DecisionTreeClassificationModel</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Trees in this ensemble. Warning: These have null parent Estimators.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="p">[</span><span class="n">DecisionTreeClassificationModel</span><span class="p">(</span><span class="n">m</span><span class="p">)</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;trees&quot;</span><span class="p">))]</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;RandomForestClassificationTrainingSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span>
<span class="sd"> trained on the training set. An exception is thrown if `trainingSummary is None`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">return</span> <span class="n">BinaryRandomForestClassificationTrainingSummary</span><span class="p">(</span>
<span class="nb">super</span><span class="p">(</span><span class="n">RandomForestClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">RandomForestClassificationTrainingSummary</span><span class="p">(</span>
<span class="nb">super</span><span class="p">(</span><span class="n">RandomForestClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;No training summary available for this </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
<span class="p">)</span>
<div class="viewcode-block" id="RandomForestClassificationModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationModel.html#pyspark.ml.classification.RandomForestClassificationModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="s2">&quot;BinaryRandomForestClassificationSummary&quot;</span><span class="p">,</span> <span class="s2">&quot;RandomForestClassificationSummary&quot;</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Evaluates the model on a test dataset.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> Test dataset to evaluate model on.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">.&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
<span class="n">java_rf_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;evaluate&quot;</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">numClasses</span> <span class="o">&lt;=</span> <span class="mi">2</span><span class="p">:</span>
<span class="k">return</span> <span class="n">BinaryRandomForestClassificationSummary</span><span class="p">(</span><span class="n">java_rf_summary</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">RandomForestClassificationSummary</span><span class="p">(</span><span class="n">java_rf_summary</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="RandomForestClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationSummary.html#pyspark.ml.classification.RandomForestClassificationSummary">[docs]</a><span class="k">class</span> <span class="nc">RandomForestClassificationSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for RandomForestClassification results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="RandomForestClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.RandomForestClassificationTrainingSummary.html#pyspark.ml.classification.RandomForestClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">RandomForestClassificationTrainingSummary</span><span class="p">(</span>
<span class="n">RandomForestClassificationSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for RandomForestClassification training results.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="BinaryRandomForestClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryRandomForestClassificationSummary.html#pyspark.ml.classification.BinaryRandomForestClassificationSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">BinaryRandomForestClassificationSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> BinaryRandomForestClassification results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="BinaryRandomForestClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.BinaryRandomForestClassificationTrainingSummary.html#pyspark.ml.classification.BinaryRandomForestClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">BinaryRandomForestClassificationTrainingSummary</span><span class="p">(</span>
<span class="n">BinaryRandomForestClassificationSummary</span><span class="p">,</span> <span class="n">RandomForestClassificationTrainingSummary</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> BinaryRandomForestClassification training results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<span class="k">class</span> <span class="nc">_GBTClassifierParams</span><span class="p">(</span><span class="n">_GBTParams</span><span class="p">,</span> <span class="n">_HasVarianceImpurity</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`GBTClassifier` and :py:class:`GBTClassificationModel`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">supportedLossTypes</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;logistic&quot;</span><span class="p">]</span>
<span class="n">lossType</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;lossType&quot;</span><span class="p">,</span>
<span class="s2">&quot;Loss function which GBT tries to minimize (case-insensitive). &quot;</span>
<span class="o">+</span> <span class="s2">&quot;Supported options: &quot;</span>
<span class="o">+</span> <span class="s2">&quot;, &quot;</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">supportedLossTypes</span><span class="p">),</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_GBTClassifierParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span>
<span class="n">maxDepth</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
<span class="n">lossType</span><span class="o">=</span><span class="s2">&quot;logistic&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span>
<span class="n">stepSize</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
<span class="n">subsamplingRate</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span>
<span class="n">impurity</span><span class="o">=</span><span class="s2">&quot;variance&quot;</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="s2">&quot;all&quot;</span><span class="p">,</span>
<span class="n">validationTol</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
<span class="n">leafCol</span><span class="o">=</span><span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getLossType</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of lossType or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">lossType</span><span class="p">)</span>
<div class="viewcode-block" id="GBTClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">GBTClassifier</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">&quot;GBTClassificationModel&quot;</span><span class="p">],</span>
<span class="n">_GBTClassifierParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;GBTClassifier&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> `Gradient-Boosted Trees (GBTs) &lt;https://en.wikipedia.org/wiki/Gradient_boosting&gt;`_</span>
<span class="sd"> learning algorithm for classification.</span>
<span class="sd"> It supports binary labels, as well as both continuous and categorical features.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> Multiclass labels are not currently supported.</span>
<span class="sd"> The implementation is based upon: J.H. Friedman. &quot;Stochastic Gradient Boosting.&quot; 1999.</span>
<span class="sd"> Gradient Boosting vs. TreeBoost:</span>
<span class="sd"> - This implementation is for Stochastic Gradient Boosting, not for TreeBoost.</span>
<span class="sd"> - Both algorithms learn tree ensembles by minimizing loss functions.</span>
<span class="sd"> - TreeBoost (Friedman, 1999) additionally modifies the outputs at tree leaf nodes</span>
<span class="sd"> based on the loss function, whereas the original gradient boosting method does not.</span>
<span class="sd"> - We expect to implement TreeBoost in the future:</span>
<span class="sd"> `SPARK-4240 &lt;https://issues.apache.org/jira/browse/SPARK-4240&gt;`_</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from numpy import allclose</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.feature import StringIndexer</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame([</span>
<span class="sd"> ... (1.0, Vectors.dense(1.0)),</span>
<span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], [&quot;label&quot;, &quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; stringIndexer = StringIndexer(inputCol=&quot;label&quot;, outputCol=&quot;indexed&quot;)</span>
<span class="sd"> &gt;&gt;&gt; si_model = stringIndexer.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; td = si_model.transform(df)</span>
<span class="sd"> &gt;&gt;&gt; gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol=&quot;indexed&quot;, seed=42,</span>
<span class="sd"> ... leafCol=&quot;leafId&quot;)</span>
<span class="sd"> &gt;&gt;&gt; gbt.setMaxIter(5)</span>
<span class="sd"> GBTClassifier...</span>
<span class="sd"> &gt;&gt;&gt; gbt.setMinWeightFractionPerNode(0.049)</span>
<span class="sd"> GBTClassifier...</span>
<span class="sd"> &gt;&gt;&gt; gbt.getMaxIter()</span>
<span class="sd"> 5</span>
<span class="sd"> &gt;&gt;&gt; gbt.getFeatureSubsetStrategy()</span>
<span class="sd"> &#39;all&#39;</span>
<span class="sd"> &gt;&gt;&gt; model = gbt.fit(td)</span>
<span class="sd"> &gt;&gt;&gt; model.getLabelCol()</span>
<span class="sd"> &#39;indexed&#39;</span>
<span class="sd"> &gt;&gt;&gt; model.setFeaturesCol(&quot;features&quot;)</span>
<span class="sd"> GBTClassificationModel...</span>
<span class="sd"> &gt;&gt;&gt; model.setThresholds([0.3, 0.7])</span>
<span class="sd"> GBTClassificationModel...</span>
<span class="sd"> &gt;&gt;&gt; model.getThresholds()</span>
<span class="sd"> [0.3, 0.7]</span>
<span class="sd"> &gt;&gt;&gt; model.featureImportances</span>
<span class="sd"> SparseVector(1, {0: 1.0})</span>
<span class="sd"> &gt;&gt;&gt; allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; test0 = spark.createDataFrame([(Vectors.dense(-1.0),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(test0.head().features)</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; model.predictRaw(test0.head().features)</span>
<span class="sd"> DenseVector([1.1697, -1.1697])</span>
<span class="sd"> &gt;&gt;&gt; model.predictProbability(test0.head().features)</span>
<span class="sd"> DenseVector([0.9121, 0.0879])</span>
<span class="sd"> &gt;&gt;&gt; result = model.transform(test0).head()</span>
<span class="sd"> &gt;&gt;&gt; result.prediction</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; result.leafId</span>
<span class="sd"> DenseVector([0.0, 0.0, 0.0, 0.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; test1 = spark.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test1).head().prediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.totalNumNodes</span>
<span class="sd"> 15</span>
<span class="sd"> &gt;&gt;&gt; print(model.toDebugString)</span>
<span class="sd"> GBTClassificationModel...numTrees=5...</span>
<span class="sd"> &gt;&gt;&gt; gbtc_path = temp_path + &quot;gbtc&quot;</span>
<span class="sd"> &gt;&gt;&gt; gbt.save(gbtc_path)</span>
<span class="sd"> &gt;&gt;&gt; gbt2 = GBTClassifier.load(gbtc_path)</span>
<span class="sd"> &gt;&gt;&gt; gbt2.getMaxDepth()</span>
<span class="sd"> 2</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;gbtc_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = GBTClassificationModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model.featureImportances == model2.featureImportances</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.treeWeights == model2.treeWeights</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.trees</span>
<span class="sd"> [DecisionTreeRegressionModel...depth=..., DecisionTreeRegressionModel...]</span>
<span class="sd"> &gt;&gt;&gt; validation = spark.createDataFrame([(0.0, Vectors.dense(-1.0),)],</span>
<span class="sd"> ... [&quot;indexed&quot;, &quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.evaluateEachIteration(validation)</span>
<span class="sd"> [0.25..., 0.23..., 0.21..., 0.19..., 0.18...]</span>
<span class="sd"> &gt;&gt;&gt; model.numClasses</span>
<span class="sd"> 2</span>
<span class="sd"> &gt;&gt;&gt; gbt = gbt.setValidationIndicatorCol(&quot;validationIndicator&quot;)</span>
<span class="sd"> &gt;&gt;&gt; gbt.getValidationIndicatorCol()</span>
<span class="sd"> &#39;validationIndicator&#39;</span>
<span class="sd"> &gt;&gt;&gt; gbt.getValidationTol()</span>
<span class="sd"> 0.01</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<!-- Listing of GBTClassifier.__init__ (keyword-only arguments): creates the Java object org.apache.spark.ml.classification.GBTClassifier through _new_java_obj with self.uid, then forwards the captured _input_kwargs to setParams. --><span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
<span class="n">lossType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;logistic&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span>
<span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;variance&quot;</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;all&quot;</span><span class="p">,</span>
<span class="n">validationTol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span>
<span class="n">validationIndicatorCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span>
<span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \</span>
<span class="sd"> lossType=&quot;logistic&quot;, maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \</span>
<span class="sd"> impurity=&quot;variance&quot;, featureSubsetStrategy=&quot;all&quot;, validationTol=0.01, \</span>
<span class="sd"> validationIndicatorCol=None, leafCol=&quot;&quot;, minWeightFractionPerNode=0.0, \</span>
<span class="sd"> weightCol=None)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">GBTClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.GBTClassifier&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<!-- Listing of GBTClassifier.setParams (keyword-only arguments, marked @since 1.4.0): forwards the captured _input_kwargs to self._set and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">maxMemoryInMB</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">256</span><span class="p">,</span>
<span class="n">cacheNodeIds</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">False</span><span class="p">,</span>
<span class="n">checkpointInterval</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">10</span><span class="p">,</span>
<span class="n">lossType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;logistic&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">20</span><span class="p">,</span>
<span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">subsamplingRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;variance&quot;</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;all&quot;</span><span class="p">,</span>
<span class="n">validationTol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span>
<span class="n">validationIndicatorCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">leafCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="p">,</span>
<span class="n">minWeightFractionPerNode</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \</span>
<span class="sd"> maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \</span>
<span class="sd"> lossType=&quot;logistic&quot;, maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \</span>
<span class="sd"> impurity=&quot;variance&quot;, featureSubsetStrategy=&quot;all&quot;, validationTol=0.01, \</span>
<span class="sd"> validationIndicatorCol=None, leafCol=&quot;&quot;, minWeightFractionPerNode=0.0, \</span>
<span class="sd"> weightCol=None)</span>
<span class="sd"> Sets params for Gradient Boosted Tree Classification.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier._create_model: returns a GBTClassificationModel constructed from the given java_model. --><span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassificationModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">GBTClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<!-- Listing of GBTClassifier.setMaxDepth: passes value to self._set(maxDepth=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setMaxDepth"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxDepth">[docs]</a> <span class="k">def</span> <span class="nf">setMaxDepth</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxDepth`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxDepth</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setMaxBins: passes value to self._set(maxBins=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setMaxBins"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxBins">[docs]</a> <span class="k">def</span> <span class="nf">setMaxBins</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxBins`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxBins</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setMinInstancesPerNode: passes value to self._set(minInstancesPerNode=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setMinInstancesPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMinInstancesPerNode">[docs]</a> <span class="k">def</span> <span class="nf">setMinInstancesPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minInstancesPerNode`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInstancesPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setMinInfoGain: passes value to self._set(minInfoGain=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setMinInfoGain"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMinInfoGain">[docs]</a> <span class="k">def</span> <span class="nf">setMinInfoGain</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minInfoGain`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minInfoGain</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setMaxMemoryInMB: passes value to self._set(maxMemoryInMB=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setMaxMemoryInMB"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxMemoryInMB">[docs]</a> <span class="k">def</span> <span class="nf">setMaxMemoryInMB</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxMemoryInMB`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxMemoryInMB</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setCacheNodeIds: passes value to self._set(cacheNodeIds=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setCacheNodeIds"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setCacheNodeIds">[docs]</a> <span class="k">def</span> <span class="nf">setCacheNodeIds</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`cacheNodeIds`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">cacheNodeIds</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setImpurity (marked @since 1.4.0): passes value to self._set(impurity=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setImpurity"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setImpurity">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setImpurity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`impurity`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">impurity</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setLossType (marked @since 1.4.0): passes value to self._set(lossType=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setLossType"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setLossType">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setLossType</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`lossType`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">lossType</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setSubsamplingRate (marked @since 1.4.0): passes value to self._set(subsamplingRate=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setSubsamplingRate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setSubsamplingRate">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setSubsamplingRate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`subsamplingRate`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">subsamplingRate</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setFeatureSubsetStrategy (marked @since 2.4.0): passes value to self._set(featureSubsetStrategy=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setFeatureSubsetStrategy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setFeatureSubsetStrategy">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFeatureSubsetStrategy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`featureSubsetStrategy`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featureSubsetStrategy</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setValidationIndicatorCol (marked @since 3.0.0): passes value to self._set(validationIndicatorCol=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setValidationIndicatorCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setValidationIndicatorCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setValidationIndicatorCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`validationIndicatorCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">validationIndicatorCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setMaxIter (marked @since 1.4.0): passes value to self._set(maxIter=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMaxIter">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxIter`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setCheckpointInterval (marked @since 1.4.0): passes value to self._set(checkpointInterval=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setCheckpointInterval"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setCheckpointInterval">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setCheckpointInterval</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`checkpointInterval`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">checkpointInterval</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setSeed (marked @since 1.4.0): passes value to self._set(seed=value) and returns that result; the [docs] anchor links back to the API reference entry. --><div class="viewcode-block" id="GBTClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setSeed">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`seed`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<!-- Listing of GBTClassifier.setStepSize (marked @since 1.4.0): passes value to self._set(stepSize=value) and returns that result. The value annotation is float because stepSize is declared as float with default 0.1 in the __init__ and setParams signatures of this class; keep the upstream Python source in sync with this annotation. --><div class="viewcode-block" id="GBTClassifier.setStepSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setStepSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.4.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setStepSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`stepSize`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">stepSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="GBTClassifier.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setWeightCol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`weightCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="GBTClassifier.setMinWeightFractionPerNode"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassifier.html#pyspark.ml.classification.GBTClassifier.setMinWeightFractionPerNode">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMinWeightFractionPerNode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;GBTClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`minWeightFractionPerNode`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">minWeightFractionPerNode</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="GBTClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassificationModel.html#pyspark.ml.classification.GBTClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">GBTClassificationModel</span><span class="p">(</span>
<span class="n">_TreeEnsembleModel</span><span class="p">,</span>
<span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_GBTClassifierParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;GBTClassificationModel&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by GBTClassifier.</span>
<span class="sd"> .. versionadded:: 1.4.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="k">def</span> <span class="nf">featureImportances</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Estimate of the importance of each feature.</span>
<span class="sd"> Each feature&#39;s importance is the average of its importance across all trees in the ensemble.</span>
<span class="sd"> The importance vector is normalized to sum to 1. This method is suggested by Hastie et al.</span>
<span class="sd"> (Hastie, Tibshirani, Friedman. &quot;The Elements of Statistical Learning, 2nd Edition.&quot; 2001.)</span>
<span class="sd"> and follows the implementation from scikit-learn.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> See Also</span>
<span class="sd"> --------</span>
<span class="sd"> DecisionTreeClassificationModel.featureImportances</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;featureImportances&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">trees</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="n">DecisionTreeRegressionModel</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Trees in this ensemble. Warning: These have null parent Estimators.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="p">[</span><span class="n">DecisionTreeRegressionModel</span><span class="p">(</span><span class="n">m</span><span class="p">)</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">list</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;trees&quot;</span><span class="p">))]</span>
<div class="viewcode-block" id="GBTClassificationModel.evaluateEachIteration"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.GBTClassificationModel.html#pyspark.ml.classification.GBTClassificationModel.evaluateEachIteration">[docs]</a> <span class="k">def</span> <span class="nf">evaluateEachIteration</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Method to compute error or loss for every iteration of gradient boosting.</span>
<span class="sd"> .. versionadded:: 2.4.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> Test dataset to evaluate model on.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;evaluateEachIteration&quot;</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span></div></div>
<span class="k">class</span> <span class="nc">_NaiveBayesParams</span><span class="p">(</span><span class="n">_PredictorParams</span><span class="p">,</span> <span class="n">HasWeightCol</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`NaiveBayes` and :py:class:`NaiveBayesModel`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">smoothing</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;smoothing&quot;</span><span class="p">,</span>
<span class="s2">&quot;The smoothing parameter, should be &gt;= 0, &quot;</span> <span class="o">+</span> <span class="s2">&quot;default is 1.0&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toFloat</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">modelType</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;modelType&quot;</span><span class="p">,</span>
<span class="s2">&quot;The model type which is a string &quot;</span>
<span class="o">+</span> <span class="s2">&quot;(case-sensitive). Supported options: multinomial (default), bernoulli &quot;</span>
<span class="o">+</span> <span class="s2">&quot;and gaussian.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_NaiveBayesParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">smoothing</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">modelType</span><span class="o">=</span><span class="s2">&quot;multinomial&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.5.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getSmoothing</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of smoothing or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">smoothing</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.5.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getModelType</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of modelType or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">modelType</span><span class="p">)</span>
<div class="viewcode-block" id="NaiveBayes"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">NaiveBayes</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">&quot;NaiveBayesModel&quot;</span><span class="p">],</span>
<span class="n">_NaiveBayesParams</span><span class="p">,</span>
<span class="n">HasThresholds</span><span class="p">,</span>
<span class="n">HasWeightCol</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;NaiveBayes&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Naive Bayes Classifiers.</span>
<span class="sd"> It supports both Multinomial and Bernoulli NB. `Multinomial NB \</span>
<span class="sd"> &lt;http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html&gt;`_</span>
<span class="sd"> can handle finitely supported discrete data. For example, by converting documents into</span>
<span class="sd"> TF-IDF vectors, it can be used for document classification. By converting every vector to</span>
<span class="sd"> binary (0/1) data, it can also be used as `Bernoulli NB \</span>
<span class="sd"> &lt;http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html&gt;`_.</span>
<span class="sd"> The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.</span>
<span class="sd"> Since 3.0.0, it supports Complement NB which is an adaptation of the Multinomial NB.</span>
<span class="sd"> Specifically, Complement NB uses statistics from the complement of each class to compute</span>
<span class="sd"> the model&#39;s coefficients. The inventors of Complement NB show empirically that the parameter</span>
<span class="sd"> estimates for CNB are more stable than those for Multinomial NB. Like Multinomial NB, the</span>
<span class="sd"> input feature values for Complement NB must be nonnegative.</span>
<span class="sd"> Since 3.0.0, it also supports `Gaussian NB \</span>
<span class="sd"> &lt;https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes&gt;`_,</span>
<span class="sd"> which can handle continuous data.</span>
<span class="sd"> .. versionadded:: 1.5.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.sql import Row</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame([</span>
<span class="sd"> ... Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),</span>
<span class="sd"> ... Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),</span>
<span class="sd"> ... Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0]))])</span>
<span class="sd"> &gt;&gt;&gt; nb = NaiveBayes(smoothing=1.0, modelType=&quot;multinomial&quot;, weightCol=&quot;weight&quot;)</span>
<span class="sd"> &gt;&gt;&gt; model = nb.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model.setFeaturesCol(&quot;features&quot;)</span>
<span class="sd"> NaiveBayesModel...</span>
<span class="sd"> &gt;&gt;&gt; model.getSmoothing()</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.pi</span>
<span class="sd"> DenseVector([-0.81..., -0.58...])</span>
<span class="sd"> &gt;&gt;&gt; model.theta</span>
<span class="sd"> DenseMatrix(2, 2, [-0.91..., -0.51..., -0.40..., -1.09...], 1)</span>
<span class="sd"> &gt;&gt;&gt; model.sigma</span>
<span class="sd"> DenseMatrix(0, 0, [...], ...)</span>
<span class="sd"> &gt;&gt;&gt; test0 = sc.parallelize([Row(features=Vectors.dense([1.0, 0.0]))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; model.predict(test0.head().features)</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predictRaw(test0.head().features)</span>
<span class="sd"> DenseVector([-1.72..., -0.99...])</span>
<span class="sd"> &gt;&gt;&gt; model.predictProbability(test0.head().features)</span>
<span class="sd"> DenseVector([0.32..., 0.67...])</span>
<span class="sd"> &gt;&gt;&gt; result = model.transform(test0).head()</span>
<span class="sd"> &gt;&gt;&gt; result.prediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; result.probability</span>
<span class="sd"> DenseVector([0.32..., 0.67...])</span>
<span class="sd"> &gt;&gt;&gt; result.rawPrediction</span>
<span class="sd"> DenseVector([-1.72..., -0.99...])</span>
<span class="sd"> &gt;&gt;&gt; test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test1).head().prediction</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; nb_path = temp_path + &quot;/nb&quot;</span>
<span class="sd"> &gt;&gt;&gt; nb.save(nb_path)</span>
<span class="sd"> &gt;&gt;&gt; nb2 = NaiveBayes.load(nb_path)</span>
<span class="sd"> &gt;&gt;&gt; nb2.getSmoothing()</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/nb_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = NaiveBayesModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model.pi == model2.pi</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.theta == model2.theta</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; nb = nb.setThresholds([0.01, 10.00])</span>
<span class="sd"> &gt;&gt;&gt; model3 = nb.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; result = model3.transform(test0).head()</span>
<span class="sd"> &gt;&gt;&gt; result.prediction</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; nb3 = NaiveBayes().setModelType(&quot;gaussian&quot;)</span>
<span class="sd"> &gt;&gt;&gt; model4 = nb3.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model4.getModelType()</span>
<span class="sd"> &#39;gaussian&#39;</span>
<span class="sd"> &gt;&gt;&gt; model4.sigma</span>
<span class="sd"> DenseMatrix(2, 2, [0.0, 0.25, 0.0, 0.0], 1)</span>
<span class="sd"> &gt;&gt;&gt; nb5 = NaiveBayes(smoothing=1.0, modelType=&quot;complement&quot;, weightCol=&quot;weight&quot;)</span>
<span class="sd"> &gt;&gt;&gt; model5 = nb5.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model5.getModelType()</span>
<span class="sd"> &#39;complement&#39;</span>
<span class="sd"> &gt;&gt;&gt; model5.theta</span>
<span class="sd"> DenseMatrix(2, 2, [...], 1)</span>
<span class="sd"> &gt;&gt;&gt; model5.sigma</span>
<span class="sd"> DenseMatrix(0, 0, [...], ...)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">modelType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;multinomial&quot;</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, smoothing=1.0, \</span>
<span class="sd"> modelType=&quot;multinomial&quot;, thresholds=None, weightCol=None)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">NaiveBayes</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.NaiveBayes&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="NaiveBayes.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.5.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">smoothing</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">modelType</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;multinomial&quot;</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;NaiveBayes&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, smoothing=1.0, \</span>
<span class="sd"> modelType=&quot;multinomial&quot;, thresholds=None, weightCol=None)</span>
<span class="sd"> Sets params for Naive Bayes.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;NaiveBayesModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">NaiveBayesModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<div class="viewcode-block" id="NaiveBayes.setSmoothing"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setSmoothing">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.5.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setSmoothing</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;NaiveBayes&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`smoothing`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">smoothing</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="NaiveBayes.setModelType"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setModelType">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.5.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setModelType</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;NaiveBayes&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`modelType`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">modelType</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="NaiveBayes.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayes.html#pyspark.ml.classification.NaiveBayes.setWeightCol">[docs]</a> <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;NaiveBayes&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`weightCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<!--
  Source listing for pyspark.ml.classification.NaiveBayesModel.
  The wrapper div's id ("NaiveBayesModel") is the in-page anchor for this class;
  the [docs] link points back to the class entry on the API reference page.
  The span classes (k, nc, n, p, s2, sd, nd, ...) are syntax-highlighting hooks
  (the page loads _static/pygments.css).
  NOTE(review): this markup looks machine-generated from the Python module;
  presumably fixes belong in the Python source, not here. Confirm before hand-editing.
-->
<div class="viewcode-block" id="NaiveBayesModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.NaiveBayesModel.html#pyspark.ml.classification.NaiveBayesModel">[docs]</a><span class="k">class</span> <span class="nc">NaiveBayesModel</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_NaiveBayesParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;NaiveBayesModel&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by NaiveBayes.</span>
<span class="sd"> .. versionadded:: 1.5.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">pi</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> log of class priors.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;pi&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">theta</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Matrix</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> log of class conditional probabilities.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;theta&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">sigma</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Matrix</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> variance of each feature.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;sigma&quot;</span><span class="p">)</span></div>
<!--
  Source listing for the private class _MultilayerPerceptronParams.
  Unlike the public classes around it, this run of highlighted spans has no
  viewcode-block wrapper div and no [docs] back-link, so it has no in-page anchor.
  It shows the layers / solver / initialWeights Param declarations, the defaults
  set in __init__, and the getLayers / getInitialWeights accessors.
  NOTE(review): text inside the spans mirrors the Python module; presumably any
  wording fix belongs in that module. Confirm before hand-editing.
-->
<span class="k">class</span> <span class="nc">_MultilayerPerceptronParams</span><span class="p">(</span>
<span class="n">_ProbabilisticClassifierParams</span><span class="p">,</span>
<span class="n">HasSeed</span><span class="p">,</span>
<span class="n">HasMaxIter</span><span class="p">,</span>
<span class="n">HasTol</span><span class="p">,</span>
<span class="n">HasStepSize</span><span class="p">,</span>
<span class="n">HasSolver</span><span class="p">,</span>
<span class="n">HasBlockSize</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`MultilayerPerceptronClassifier`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">layers</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;layers&quot;</span><span class="p">,</span>
<span class="s2">&quot;Sizes of layers from input layer to output layer &quot;</span>
<span class="o">+</span> <span class="s2">&quot;E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 &quot;</span>
<span class="o">+</span> <span class="s2">&quot;neurons and output layer of 10 neurons.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toListInt</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">solver</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;solver&quot;</span><span class="p">,</span>
<span class="s2">&quot;The solver algorithm for optimization. Supported &quot;</span> <span class="o">+</span> <span class="s2">&quot;options: l-bfgs, gd.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toString</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">initialWeights</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span>
<span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span>
<span class="s2">&quot;initialWeights&quot;</span><span class="p">,</span>
<span class="s2">&quot;The initial weights of the model.&quot;</span><span class="p">,</span>
<span class="n">typeConverter</span><span class="o">=</span><span class="n">TypeConverters</span><span class="o">.</span><span class="n">toVector</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">:</span> <span class="n">Any</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">_MultilayerPerceptronParams</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span> <span class="n">blockSize</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">stepSize</span><span class="o">=</span><span class="mf">0.03</span><span class="p">,</span> <span class="n">solver</span><span class="o">=</span><span class="s2">&quot;l-bfgs&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.6.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getLayers</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of layers or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getInitialWeights</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of initialWeights or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initialWeights</span><span class="p">)</span>
<!--
  Source listing for pyspark.ml.classification.MultilayerPerceptronClassifier.
  The outer viewcode-block div (id "MultilayerPerceptronClassifier") wraps the whole
  class and is closed by the second closing div tag at the end of the setSolver listing.
  Nested viewcode-block divs give per-method anchors, each with its own [docs] back-link
  to the class page under reference/api/:
    setParams, setLayers, setBlockSize, setInitialWeights, setMaxIter,
    setSeed, setTol, setStepSize, setSolver.
  The class docstring, __init__ and _create_model are listed without an anchor of their own.
  NOTE(review): this markup looks machine-generated from the Python module;
  presumably fixes belong in the Python source, not here. Confirm before hand-editing.
-->
<div class="viewcode-block" id="MultilayerPerceptronClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">MultilayerPerceptronClassifier</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">&quot;MultilayerPerceptronClassificationModel&quot;</span><span class="p">],</span>
<span class="n">_MultilayerPerceptronParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Classifier trainer based on the Multilayer Perceptron.</span>
<span class="sd"> Each layer has sigmoid activation function, output layer has softmax.</span>
<span class="sd"> Number of inputs has to be equal to the size of feature vectors.</span>
<span class="sd"> Number of outputs has to be equal to the total number of labels.</span>
<span class="sd"> .. versionadded:: 1.6.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame([</span>
<span class="sd"> ... (0.0, Vectors.dense([0.0, 0.0])),</span>
<span class="sd"> ... (1.0, Vectors.dense([0.0, 1.0])),</span>
<span class="sd"> ... (1.0, Vectors.dense([1.0, 0.0])),</span>
<span class="sd"> ... (0.0, Vectors.dense([1.0, 1.0]))], [&quot;label&quot;, &quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; mlp = MultilayerPerceptronClassifier(layers=[2, 2, 2], seed=123)</span>
<span class="sd"> &gt;&gt;&gt; mlp.setMaxIter(100)</span>
<span class="sd"> MultilayerPerceptronClassifier...</span>
<span class="sd"> &gt;&gt;&gt; mlp.getMaxIter()</span>
<span class="sd"> 100</span>
<span class="sd"> &gt;&gt;&gt; mlp.getBlockSize()</span>
<span class="sd"> 128</span>
<span class="sd"> &gt;&gt;&gt; mlp.setBlockSize(1)</span>
<span class="sd"> MultilayerPerceptronClassifier...</span>
<span class="sd"> &gt;&gt;&gt; mlp.getBlockSize()</span>
<span class="sd"> 1</span>
<span class="sd"> &gt;&gt;&gt; model = mlp.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model.setFeaturesCol(&quot;features&quot;)</span>
<span class="sd"> MultilayerPerceptronClassificationModel...</span>
<span class="sd"> &gt;&gt;&gt; model.getMaxIter()</span>
<span class="sd"> 100</span>
<span class="sd"> &gt;&gt;&gt; model.getLayers()</span>
<span class="sd"> [2, 2, 2]</span>
<span class="sd"> &gt;&gt;&gt; model.weights.size</span>
<span class="sd"> 12</span>
<span class="sd"> &gt;&gt;&gt; testDF = spark.createDataFrame([</span>
<span class="sd"> ... (Vectors.dense([1.0, 0.0]),),</span>
<span class="sd"> ... (Vectors.dense([0.0, 0.0]),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(testDF.head().features)</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predictRaw(testDF.head().features)</span>
<span class="sd"> DenseVector([-16.208, 16.344])</span>
<span class="sd"> &gt;&gt;&gt; model.predictProbability(testDF.head().features)</span>
<span class="sd"> DenseVector([0.0, 1.0])</span>
<span class="sd"> &gt;&gt;&gt; model.transform(testDF).select(&quot;features&quot;, &quot;prediction&quot;).show()</span>
<span class="sd"> +---------+----------+</span>
<span class="sd"> | features|prediction|</span>
<span class="sd"> +---------+----------+</span>
<span class="sd"> |[1.0,0.0]| 1.0|</span>
<span class="sd"> |[0.0,0.0]| 0.0|</span>
<span class="sd"> +---------+----------+</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; mlp_path = temp_path + &quot;/mlp&quot;</span>
<span class="sd"> &gt;&gt;&gt; mlp.save(mlp_path)</span>
<span class="sd"> &gt;&gt;&gt; mlp2 = MultilayerPerceptronClassifier.load(mlp_path)</span>
<span class="sd"> &gt;&gt;&gt; mlp2.getBlockSize()</span>
<span class="sd"> 1</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/mlp_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = MultilayerPerceptronClassificationModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model.getLayers() == model2.getLayers()</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.weights == model2.weights</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.transform(testDF).take(1) == model2.transform(testDF).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; mlp2 = mlp2.setInitialWeights(list(range(0, 12)))</span>
<span class="sd"> &gt;&gt;&gt; model3 = mlp2.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model3.weights != model2.weights</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model3.getLayers() == model.getLayers()</span>
<span class="sd"> True</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">layers</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">blockSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span>
<span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.03</span><span class="p">,</span>
<span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;l-bfgs&quot;</span><span class="p">,</span>
<span class="n">initialWeights</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \</span>
<span class="sd"> solver=&quot;l-bfgs&quot;, initialWeights=None, probabilityCol=&quot;probability&quot;, \</span>
<span class="sd"> rawPredictionCol=&quot;rawPrediction&quot;)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">MultilayerPerceptronClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.MultilayerPerceptronClassifier&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.6.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">layers</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">blockSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">128</span><span class="p">,</span>
<span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.03</span><span class="p">,</span>
<span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;l-bfgs&quot;</span><span class="p">,</span>
<span class="n">initialWeights</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Vector</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> maxIter=100, tol=1e-6, seed=None, layers=None, blockSize=128, stepSize=0.03, \</span>
<span class="sd"> solver=&quot;l-bfgs&quot;, initialWeights=None, probabilityCol=&quot;probability&quot;, \</span>
<span class="sd"> rawPredictionCol=&quot;rawPrediction&quot;):</span>
<span class="sd"> Sets params for MultilayerPerceptronClassifier.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassificationModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">MultilayerPerceptronClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setLayers"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setLayers">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.6.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setLayers</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`layers`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">layers</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setBlockSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setBlockSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.6.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setBlockSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`blockSize`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">blockSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setInitialWeights"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setInitialWeights">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setInitialWeights</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Vector</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`initialWeights`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">initialWeights</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setMaxIter">[docs]</a> <span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxIter`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setSeed">[docs]</a> <span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`seed`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setTol">[docs]</a> <span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`tol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setStepSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setStepSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setStepSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`stepSize`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">stepSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassifier.setSolver"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassifier.html#pyspark.ml.classification.MultilayerPerceptronClassifier.setSolver">[docs]</a> <span class="k">def</span> <span class="nf">setSolver</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`solver`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">solver</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="MultilayerPerceptronClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationModel.html#pyspark.ml.classification.MultilayerPerceptronClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">MultilayerPerceptronClassificationModel</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_MultilayerPerceptronParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;MultilayerPerceptronClassificationModel&quot;</span><span class="p">],</span>
<span class="n">HasTrainingSummary</span><span class="p">[</span><span class="s2">&quot;MultilayerPerceptronClassificationTrainingSummary&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by MultilayerPerceptronClassifier.</span>
<span class="sd"> .. versionadded:: 1.6.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">weights</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> The weights of layers.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;weights&quot;</span><span class="p">)</span>
<div class="viewcode-block" id="MultilayerPerceptronClassificationModel.summary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationModel.html#pyspark.ml.classification.MultilayerPerceptronClassificationModel.summary">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">summary</span><span class="p">(</span> <span class="c1"># type: ignore[override]</span>
<span class="bp">self</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassificationTrainingSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span>
<span class="sd"> trained on the training set. A RuntimeError is raised if no training summary is available.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span>
<span class="k">return</span> <span class="n">MultilayerPerceptronClassificationTrainingSummary</span><span class="p">(</span>
<span class="nb">super</span><span class="p">(</span><span class="n">MultilayerPerceptronClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span>
<span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;No training summary available for this </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassificationModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationModel.html#pyspark.ml.classification.MultilayerPerceptronClassificationModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;MultilayerPerceptronClassificationSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Evaluates the model on a test dataset.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> Test dataset to evaluate model on.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">.&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
<span class="n">java_mlp_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;evaluate&quot;</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="k">return</span> <span class="n">MultilayerPerceptronClassificationSummary</span><span class="p">(</span><span class="n">java_mlp_summary</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="MultilayerPerceptronClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationSummary.html#pyspark.ml.classification.MultilayerPerceptronClassificationSummary">[docs]</a><span class="k">class</span> <span class="nc">MultilayerPerceptronClassificationSummary</span><span class="p">(</span><span class="n">_ClassificationSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for MultilayerPerceptronClassifier Results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="MultilayerPerceptronClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.MultilayerPerceptronClassificationTrainingSummary.html#pyspark.ml.classification.MultilayerPerceptronClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">MultilayerPerceptronClassificationTrainingSummary</span><span class="p">(</span>
<span class="n">MultilayerPerceptronClassificationSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for MultilayerPerceptronClassifier Training results.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<span class="k">class</span> <span class="nc">_OneVsRestParams</span><span class="p">(</span><span class="n">_ClassifierParams</span><span class="p">,</span> <span class="n">HasWeightCol</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Params for :py:class:`OneVsRest` and :py:class:`OneVsRestModel`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">classifier</span><span class="p">:</span> <span class="n">Param</span><span class="p">[</span><span class="n">Classifier</span><span class="p">]</span> <span class="o">=</span> <span class="n">Param</span><span class="p">(</span><span class="n">Params</span><span class="o">.</span><span class="n">_dummy</span><span class="p">(),</span> <span class="s2">&quot;classifier&quot;</span><span class="p">,</span> <span class="s2">&quot;base binary classifier&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">getClassifier</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Classifier</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets the value of classifier or its default value.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">getOrDefault</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classifier</span><span class="p">)</span>
<div class="viewcode-block" id="OneVsRest"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">OneVsRest</span><span class="p">(</span>
<span class="n">Estimator</span><span class="p">[</span><span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">],</span>
<span class="n">_OneVsRestParams</span><span class="p">,</span>
<span class="n">HasParallelism</span><span class="p">,</span>
<span class="n">MLReadable</span><span class="p">[</span><span class="s2">&quot;OneVsRest&quot;</span><span class="p">],</span>
<span class="n">MLWritable</span><span class="p">,</span>
<span class="n">Generic</span><span class="p">[</span><span class="n">CM</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Reduction of Multiclass Classification to Binary Classification.</span>
<span class="sd"> Performs reduction using one against all strategy.</span>
<span class="sd"> For a multiclass classification with k classes, train k models (one per class).</span>
<span class="sd"> Each example is scored against all k models and the model with highest score</span>
<span class="sd"> is picked to label the example.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.sql import Row</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; data_path = &quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span>
<span class="sd"> &gt;&gt;&gt; df = spark.read.format(&quot;libsvm&quot;).load(data_path)</span>
<span class="sd"> &gt;&gt;&gt; lr = LogisticRegression(regParam=0.01)</span>
<span class="sd"> &gt;&gt;&gt; ovr = OneVsRest(classifier=lr)</span>
<span class="sd"> &gt;&gt;&gt; ovr.getRawPredictionCol()</span>
<span class="sd"> &#39;rawPrediction&#39;</span>
<span class="sd"> &gt;&gt;&gt; ovr.setPredictionCol(&quot;newPrediction&quot;)</span>
<span class="sd"> OneVsRest...</span>
<span class="sd"> &gt;&gt;&gt; model = ovr.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model.models[0].coefficients</span>
<span class="sd"> DenseVector([0.5..., -1.0..., 3.4..., 4.2...])</span>
<span class="sd"> &gt;&gt;&gt; model.models[1].coefficients</span>
<span class="sd"> DenseVector([-2.1..., 3.1..., -2.6..., -2.3...])</span>
<span class="sd"> &gt;&gt;&gt; model.models[2].coefficients</span>
<span class="sd"> DenseVector([0.3..., -3.4..., 1.0..., -1.1...])</span>
<span class="sd"> &gt;&gt;&gt; [x.intercept for x in model.models]</span>
<span class="sd"> [-2.7..., -2.5..., -1.3...]</span>
<span class="sd"> &gt;&gt;&gt; test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).head().newPrediction</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test1).head().newPrediction</span>
<span class="sd"> 2.0</span>
<span class="sd"> &gt;&gt;&gt; test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF()</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test2).head().newPrediction</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/ovr_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = OneVsRestModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2.transform(test0).head().newPrediction</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test2).columns</span>
<span class="sd"> [&#39;features&#39;, &#39;rawPrediction&#39;, &#39;newPrediction&#39;]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">classifier</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Classifier</span><span class="p">[</span><span class="n">CM</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> rawPredictionCol=&quot;rawPrediction&quot;, classifier=None, weightCol=None, parallelism=1):</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_setDefault</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="OneVsRest.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">classifier</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Classifier</span><span class="p">[</span><span class="n">CM</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">weightCol</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">str</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">parallelism</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> rawPredictionCol=&quot;rawPrediction&quot;, classifier=None, weightCol=None, parallelism=1):</span>
<span class="sd"> Sets params for OneVsRest.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.setClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setClassifier">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;2.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setClassifier</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="n">Classifier</span><span class="p">[</span><span class="n">CM</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`classifier`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.setLabelCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setLabelCol">[docs]</a> <span class="k">def</span> <span class="nf">setLabelCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`labelCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.setFeaturesCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setFeaturesCol">[docs]</a> <span class="k">def</span> <span class="nf">setFeaturesCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`featuresCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.setPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`predictionCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">predictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.setRawPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setRawPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.setWeightCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setWeightCol">[docs]</a> <span class="k">def</span> <span class="nf">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`weightCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.setParallelism"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.setParallelism">[docs]</a> <span class="k">def</span> <span class="nf">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`parallelism`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">parallelism</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_fit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">:</span>
<span class="n">labelCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">()</span>
<span class="n">featuresCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">()</span>
<span class="n">predictionCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">()</span>
<span class="n">classifier</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()</span>
<span class="n">numClasses</span> <span class="o">=</span> <span class="p">(</span>
<span class="nb">int</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">Row</span><span class="p">,</span> <span class="n">dataset</span><span class="o">.</span><span class="n">agg</span><span class="p">({</span><span class="n">labelCol</span><span class="p">:</span> <span class="s2">&quot;max&quot;</span><span class="p">})</span><span class="o">.</span><span class="n">head</span><span class="p">())[</span><span class="s2">&quot;max(&quot;</span> <span class="o">+</span> <span class="n">labelCol</span> <span class="o">+</span> <span class="s2">&quot;)&quot;</span><span class="p">])</span> <span class="o">+</span> <span class="mi">1</span>
<span class="p">)</span>
<span class="n">weightCol</span> <span class="o">=</span> <span class="kc">None</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weightCol</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">():</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">classifier</span><span class="p">,</span> <span class="n">HasWeightCol</span><span class="p">):</span>
<span class="n">weightCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">()</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">warnings</span><span class="o">.</span><span class="n">warn</span><span class="p">(</span>
<span class="s2">&quot;weightCol is ignored, &quot;</span> <span class="s2">&quot;as it is not supported by </span><span class="si">{}</span><span class="s2"> now.&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">classifier</span><span class="p">)</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">weightCol</span><span class="p">:</span>
<span class="n">multiclassLabeled</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="n">labelCol</span><span class="p">,</span> <span class="n">featuresCol</span><span class="p">,</span> <span class="n">weightCol</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">multiclassLabeled</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="n">labelCol</span><span class="p">,</span> <span class="n">featuresCol</span><span class="p">)</span>
<span class="c1"># persist if underlying dataset is not persistent.</span>
<span class="n">handlePersistence</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">storageLevel</span> <span class="o">==</span> <span class="n">StorageLevel</span><span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span>
<span class="n">multiclassLabeled</span><span class="o">.</span><span class="n">persist</span><span class="p">(</span><span class="n">StorageLevel</span><span class="o">.</span><span class="n">MEMORY_AND_DISK</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">trainSingleClass</span><span class="p">(</span><span class="n">index</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">CM</span><span class="p">:</span>
<span class="n">binaryLabelCol</span> <span class="o">=</span> <span class="s2">&quot;mc2b$&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">index</span><span class="p">)</span>
<span class="n">trainingDataset</span> <span class="o">=</span> <span class="n">multiclassLabeled</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span>
<span class="n">binaryLabelCol</span><span class="p">,</span>
<span class="n">when</span><span class="p">(</span><span class="n">multiclassLabeled</span><span class="p">[</span><span class="n">labelCol</span><span class="p">]</span> <span class="o">==</span> <span class="nb">float</span><span class="p">(</span><span class="n">index</span><span class="p">),</span> <span class="mf">1.0</span><span class="p">)</span><span class="o">.</span><span class="n">otherwise</span><span class="p">(</span><span class="mf">0.0</span><span class="p">),</span>
<span class="p">)</span>
<span class="n">paramMap</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span>
<span class="p">[</span>
<span class="p">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">labelCol</span><span class="p">,</span> <span class="n">binaryLabelCol</span><span class="p">),</span>
<span class="p">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">featuresCol</span><span class="p">,</span> <span class="n">featuresCol</span><span class="p">),</span>
<span class="p">(</span><span class="n">classifier</span><span class="o">.</span><span class="n">predictionCol</span><span class="p">,</span> <span class="n">predictionCol</span><span class="p">),</span>
<span class="p">]</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">weightCol</span><span class="p">:</span>
<span class="n">paramMap</span><span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">HasWeightCol</span><span class="p">,</span> <span class="n">classifier</span><span class="p">)</span><span class="o">.</span><span class="n">weightCol</span><span class="p">]</span> <span class="o">=</span> <span class="n">weightCol</span>
<span class="k">return</span> <span class="n">classifier</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingDataset</span><span class="p">,</span> <span class="n">paramMap</span><span class="p">)</span>
<span class="n">pool</span> <span class="o">=</span> <span class="n">ThreadPool</span><span class="p">(</span><span class="n">processes</span><span class="o">=</span><span class="nb">min</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">(),</span> <span class="n">numClasses</span><span class="p">))</span>
<span class="n">models</span> <span class="o">=</span> <span class="n">pool</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">inheritable_thread_target</span><span class="p">(</span><span class="n">trainSingleClass</span><span class="p">),</span> <span class="nb">range</span><span class="p">(</span><span class="n">numClasses</span><span class="p">))</span>
<span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span>
<span class="n">multiclassLabeled</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_copyValues</span><span class="p">(</span><span class="n">OneVsRestModel</span><span class="p">(</span><span class="n">models</span><span class="o">=</span><span class="n">models</span><span class="p">))</span>
<div class="viewcode-block" id="OneVsRest.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Creates a copy of this instance with a randomly generated uid</span>
<span class="sd"> and some extra params. This creates a deep copy of the embedded paramMap,</span>
<span class="sd"> and copies the embedded and extra parameters over.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> extra : dict, optional</span>
<span class="sd"> Extra parameters to copy to the new instance</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`OneVsRest`</span>
<span class="sd"> Copy of this instance</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">newOvr</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isSet</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">classifier</span><span class="p">):</span>
<span class="n">newOvr</span><span class="o">.</span><span class="n">setClassifier</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">))</span>
<span class="k">return</span> <span class="n">newOvr</span></div>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRest&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given a Java OneVsRest, create and return a Python wrapper of it.</span>
<span class="sd"> Used for ML persistence.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">featuresCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">()</span>
<span class="n">labelCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">()</span>
<span class="n">predictionCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">()</span>
<span class="n">rawPredictionCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">()</span>
<span class="n">classifier</span><span class="p">:</span> <span class="n">Classifier</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span>
<span class="n">parallelism</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">()</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span>
<span class="n">featuresCol</span><span class="o">=</span><span class="n">featuresCol</span><span class="p">,</span>
<span class="n">labelCol</span><span class="o">=</span><span class="n">labelCol</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="o">=</span><span class="n">predictionCol</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">rawPredictionCol</span><span class="p">,</span>
<span class="n">classifier</span><span class="o">=</span><span class="n">classifier</span><span class="p">,</span>
<span class="n">parallelism</span><span class="o">=</span><span class="n">parallelism</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="s2">&quot;weightCol&quot;</span><span class="p">)):</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">setWeightCol</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span>
<span class="k">return</span> <span class="n">py_stage</span>
<span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Transfer this instance to a Java OneVsRest. Used for ML persistence.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> py4j.java_gateway.JavaObject</span>
<span class="sd"> Java object equivalent to this instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.OneVsRest&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setClassifier</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setParallelism</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getParallelism</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setFeaturesCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setLabelCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">())</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weightCol</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">():</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setWeightCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">())</span>
<span class="k">return</span> <span class="n">_java_obj</span>
<div class="viewcode-block" id="OneVsRest.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.read">[docs]</a> <span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestReader&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">OneVsRestReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRest.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRest.html#pyspark.ml.classification.OneVsRest.write">[docs]</a> <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MLWriter</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">(),</span> <span class="n">JavaMLWritable</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">OneVsRestWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div></div>
<span class="k">class</span> <span class="nc">_OneVsRestSharedReadWrite</span><span class="p">:</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span>
<span class="n">instance</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">],</span>
<span class="n">sc</span><span class="p">:</span> <span class="s2">&quot;SparkContext&quot;</span><span class="p">,</span>
<span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">extraMetadata</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">skipParams</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;classifier&quot;</span><span class="p">]</span>
<span class="n">jsonParams</span> <span class="o">=</span> <span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">extractJsonParams</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">skipParams</span><span class="p">)</span>
<span class="n">DefaultParamsWriter</span><span class="o">.</span><span class="n">saveMetadata</span><span class="p">(</span>
<span class="n">instance</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">sc</span><span class="p">,</span> <span class="n">paramMap</span><span class="o">=</span><span class="n">jsonParams</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span>
<span class="p">)</span>
<span class="n">classifierPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;classifier&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">classifierPath</span><span class="p">)</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">loadClassifier</span><span class="p">(</span><span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span> <span class="n">sc</span><span class="p">:</span> <span class="s2">&quot;SparkContext&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">]:</span>
<span class="n">classifierPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="s2">&quot;classifier&quot;</span><span class="p">)</span>
<span class="k">return</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">classifierPath</span><span class="p">,</span> <span class="n">sc</span><span class="p">)</span>
<span class="nd">@staticmethod</span>
<span class="k">def</span> <span class="nf">validateParams</span><span class="p">(</span><span class="n">instance</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">,</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">elems_to_check</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">Params</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">instance</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()]</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">OneVsRestModel</span><span class="p">):</span>
<span class="n">elems_to_check</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">models</span><span class="p">)</span>
<span class="k">for</span> <span class="n">elem</span> <span class="ow">in</span> <span class="n">elems_to_check</span><span class="p">:</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">MLWritable</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span>
<span class="sa">f</span><span class="s2">&quot;OneVsRest write will fail because it contains </span><span class="si">{</span><span class="n">elem</span><span class="o">.</span><span class="n">uid</span><span class="si">}</span><span class="s2"> &quot;</span>
<span class="sa">f</span><span class="s2">&quot;which is not writable.&quot;</span>
<span class="p">)</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">OneVsRestReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">]):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="n">OneVsRest</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">OneVsRest</span><span class="p">:</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">classifier</span> <span class="o">=</span> <span class="n">cast</span><span class="p">(</span><span class="n">Classifier</span><span class="p">,</span> <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">loadClassifier</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">))</span>
<span class="n">ova</span><span class="p">:</span> <span class="n">OneVsRest</span> <span class="o">=</span> <span class="n">OneVsRest</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">classifier</span><span class="p">)</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;uid&quot;</span><span class="p">])</span>
<span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span><span class="n">ova</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;classifier&quot;</span><span class="p">])</span>
<span class="k">return</span> <span class="n">ova</span>
<span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">OneVsRestWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="n">OneVsRest</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span>
<span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span>
<div class="viewcode-block" id="OneVsRestModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel">[docs]</a><span class="k">class</span> <span class="nc">OneVsRestModel</span><span class="p">(</span>
<span class="n">Model</span><span class="p">,</span>
<span class="n">_OneVsRestParams</span><span class="p">,</span>
<span class="n">MLReadable</span><span class="p">[</span><span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">],</span>
<span class="n">MLWritable</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by OneVsRest.</span>
<span class="sd"> This stores the models resulting from training k binary classifiers: one for each class.</span>
<span class="sd"> Each example is scored against all k models, and the model with the highest score</span>
<span class="sd"> is picked to label the example.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<div class="viewcode-block" id="OneVsRestModel.setFeaturesCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.setFeaturesCol">[docs]</a> <span class="k">def</span> <span class="nf">setFeaturesCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`featuresCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRestModel.setPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.setPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`predictionCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">predictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="OneVsRestModel.setRawPredictionCol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.setRawPredictionCol">[docs]</a> <span class="k">def</span> <span class="nf">setRawPredictionCol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`rawPredictionCol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">rawPredictionCol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">models</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ClassificationModel</span><span class="p">]):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="kn">from</span> <span class="nn">pyspark.core.context</span> <span class="kn">import</span> <span class="n">SparkContext</span>
<span class="bp">self</span><span class="o">.</span><span class="n">models</span> <span class="o">=</span> <span class="n">models</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">models</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">JavaMLWritable</span><span class="p">):</span>
<span class="k">return</span>
<span class="c1"># set java instance</span>
<span class="n">java_models</span> <span class="o">=</span> <span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassificationModel</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">]</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">java_models_array</span> <span class="o">=</span> <span class="n">JavaWrapper</span><span class="o">.</span><span class="n">_new_java_array</span><span class="p">(</span>
<span class="n">java_models</span><span class="p">,</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span><span class="o">.</span><span class="n">jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">classification</span><span class="o">.</span><span class="n">ClassificationModel</span>
<span class="p">)</span>
<span class="c1"># TODO: need to set metadata</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">&quot;org.apache.spark.sql.types.Metadata&quot;</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.OneVsRestModel&quot;</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span>
<span class="n">metadata</span><span class="o">.</span><span class="n">empty</span><span class="p">(),</span>
<span class="n">java_models_array</span><span class="p">,</span>
<span class="p">)</span>
<!--
  Highlighted source of the _transform(self, dataset) method.
  What the code shown does, in order:
    1. Remembers the input column names and adds an accumulator column
       (name: "mbc$acc" + uuid4) that starts as an empty array of doubles.
    2. If the input's storage level equals StorageLevel(False, False, False, False),
       persists the working DataFrame with MEMORY_AND_DISK; the matching
       unpersist() runs right after the loop.
    3. For every model in self.models: transforms the running DataFrame, keeps
       the original columns plus the raw prediction and accumulator columns,
       appends element [1] of the raw prediction to the accumulator through a
       temporary column ("mbc$tmp" + uuid4), then renames that temporary
       column back to the accumulator name.
    4. If getRawPredictionCol() is truthy, stores the accumulated values as a
       dense vector in that column.
    5. If getPredictionCol() is truthy, stores the index of the largest
       accumulated value, converted to float, in that column.
    6. Drops the accumulator column and returns the result.
  NOTE(review): each sub-model's output is read under the name returned by
  self.getRawPredictionCol(); presumably the sub-models are configured with
  that same column name. Confirm against the fitting code.
  NOTE(review): nothing between persist() and unpersist() obviously forces
  evaluation; confirm the cache is ever populated.
--><span class="k">def</span> <span class="nf">_transform</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DataFrame</span><span class="p">:</span>
<span class="c1"># determine the input columns: these need to be passed through</span>
<span class="n">origCols</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">columns</span>
<span class="c1"># add an accumulator column to store predictions of all the models</span>
<span class="n">accColName</span> <span class="o">=</span> <span class="s2">&quot;mbc$acc&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
<span class="n">initUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="p">[],</span> <span class="n">ArrayType</span><span class="p">(</span><span class="n">DoubleType</span><span class="p">()))</span>
<span class="n">newDataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span><span class="n">accColName</span><span class="p">,</span> <span class="n">initUDF</span><span class="p">(</span><span class="n">dataset</span><span class="p">[</span><span class="n">origCols</span><span class="p">[</span><span class="mi">0</span><span class="p">]]))</span>
<span class="c1"># persist if underlying dataset is not persistent.</span>
<span class="n">handlePersistence</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">storageLevel</span> <span class="o">==</span> <span class="n">StorageLevel</span><span class="p">(</span><span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span>
<span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span>
<span class="n">newDataset</span><span class="o">.</span><span class="n">persist</span><span class="p">(</span><span class="n">StorageLevel</span><span class="o">.</span><span class="n">MEMORY_AND_DISK</span><span class="p">)</span>
<span class="c1"># update the accumulator column with the result of prediction of models</span>
<span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">newDataset</span>
<span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">model</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">):</span>
<span class="n">rawPredictionCol</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">()</span>
<span class="n">columns</span> <span class="o">=</span> <span class="n">origCols</span> <span class="o">+</span> <span class="p">[</span><span class="n">rawPredictionCol</span><span class="p">,</span> <span class="n">accColName</span><span class="p">]</span>
<span class="c1"># add temporary column to store intermediate scores and update</span>
<span class="n">tmpColName</span> <span class="o">=</span> <span class="s2">&quot;mbc$tmp&quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">uuid</span><span class="o">.</span><span class="n">uuid4</span><span class="p">())</span>
<span class="n">updateUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span>
<span class="k">lambda</span> <span class="n">predictions</span><span class="p">,</span> <span class="n">prediction</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">+</span> <span class="p">[</span><span class="n">prediction</span><span class="o">.</span><span class="n">tolist</span><span class="p">()[</span><span class="mi">1</span><span class="p">]],</span>
<span class="n">ArrayType</span><span class="p">(</span><span class="n">DoubleType</span><span class="p">()),</span>
<span class="p">)</span>
<span class="n">transformedDataset</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">aggregatedDataset</span><span class="p">)</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="o">*</span><span class="n">columns</span><span class="p">)</span>
<span class="n">updatedDataset</span> <span class="o">=</span> <span class="n">transformedDataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span>
<span class="n">tmpColName</span><span class="p">,</span>
<span class="n">updateUDF</span><span class="p">(</span><span class="n">transformedDataset</span><span class="p">[</span><span class="n">accColName</span><span class="p">],</span> <span class="n">transformedDataset</span><span class="p">[</span><span class="n">rawPredictionCol</span><span class="p">]),</span>
<span class="p">)</span>
<span class="n">newColumns</span> <span class="o">=</span> <span class="n">origCols</span> <span class="o">+</span> <span class="p">[</span><span class="n">tmpColName</span><span class="p">]</span>
<span class="c1"># switch out the intermediate column with the accumulator column</span>
<span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">updatedDataset</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="o">*</span><span class="n">newColumns</span><span class="p">)</span><span class="o">.</span><span class="n">withColumnRenamed</span><span class="p">(</span>
<span class="n">tmpColName</span><span class="p">,</span> <span class="n">accColName</span>
<span class="p">)</span>
<span class="k">if</span> <span class="n">handlePersistence</span><span class="p">:</span>
<span class="n">newDataset</span><span class="o">.</span><span class="n">unpersist</span><span class="p">()</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">():</span>
<span class="k">def</span> <span class="nf">func</span><span class="p">(</span><span class="n">predictions</span><span class="p">:</span> <span class="n">Iterable</span><span class="p">[</span><span class="nb">float</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="n">predArray</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">predictions</span><span class="p">:</span>
<span class="n">predArray</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="k">return</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="n">predArray</span><span class="p">)</span>
<span class="n">rawPredictionUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">VectorUDT</span><span class="p">())</span>
<span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">aggregatedDataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">getRawPredictionCol</span><span class="p">(),</span> <span class="n">rawPredictionUDF</span><span class="p">(</span><span class="n">aggregatedDataset</span><span class="p">[</span><span class="n">accColName</span><span class="p">])</span>
<span class="p">)</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">():</span>
<span class="c1"># output the index of the classifier with highest confidence as prediction</span>
<span class="n">labelUDF</span> <span class="o">=</span> <span class="n">udf</span><span class="p">(</span>
<span class="k">lambda</span> <span class="n">predictions</span><span class="p">:</span> <span class="nb">float</span><span class="p">(</span>
<span class="nb">max</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">predictions</span><span class="p">),</span> <span class="n">key</span><span class="o">=</span><span class="n">operator</span><span class="o">.</span><span class="n">itemgetter</span><span class="p">(</span><span class="mi">1</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="p">),</span>
<span class="n">DoubleType</span><span class="p">(),</span>
<span class="p">)</span>
<span class="n">aggregatedDataset</span> <span class="o">=</span> <span class="n">aggregatedDataset</span><span class="o">.</span><span class="n">withColumn</span><span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">(),</span> <span class="n">labelUDF</span><span class="p">(</span><span class="n">aggregatedDataset</span><span class="p">[</span><span class="n">accColName</span><span class="p">])</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">aggregatedDataset</span><span class="o">.</span><span class="n">drop</span><span class="p">(</span><span class="n">accColName</span><span class="p">)</span>
<!--
  Viewcode block for OneVsRestModel.copy. The div id is the anchor target and
  the "[docs]" link points back to the method's API reference page.
  The listed code treats a missing "extra" as an empty dict, builds the copy
  with Params.copy(self, extra), replaces the copy's "models" with
  model.copy(extra) for each entry of self.models, and returns the copy.
--><div class="viewcode-block" id="OneVsRestModel.copy"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.copy">[docs]</a> <span class="k">def</span> <span class="nf">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="s2">&quot;ParamMap&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Creates a copy of this instance with a randomly generated uid</span>
<span class="sd"> and some extra params. This creates a deep copy of the embedded paramMap,</span>
<span class="sd"> and copies the embedded and extra parameters over.</span>
<span class="sd"> .. versionadded:: 2.0.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> extra : dict, optional</span>
<span class="sd"> Extra parameters to copy to the new instance</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`OneVsRestModel`</span>
<span class="sd"> Copy of this instance</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="n">extra</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">extra</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">()</span>
<span class="n">newModel</span> <span class="o">=</span> <span class="n">Params</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">extra</span><span class="p">)</span>
<span class="n">newModel</span><span class="o">.</span><span class="n">models</span> <span class="o">=</span> <span class="p">[</span><span class="n">model</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">extra</span><span class="p">)</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">]</span>
<span class="k">return</span> <span class="n">newModel</span></div>
<!--
  Highlighted source of the _from_java(cls, java_stage) classmethod.
  The listed code:
    * reads featuresCol, labelCol and predictionCol from java_stage;
    * converts java_stage.getClassifier() and every element of
      java_stage.models() with JavaParams._from_java;
    * builds cls(models=models), applies predictionCol and featuresCol through
      their setters, and labelCol and classifier through _set;
    * copies weightCol only when java_stage reports that param as defined;
    * resets the Python uid to java_stage.uid() and returns the wrapper.
  NOTE(review): rawPredictionCol is not read from java_stage in the code
  shown. Confirm whether that is intentional.
--><span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_from_java</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">java_stage</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestModel&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Given a Java OneVsRestModel, create and return a Python wrapper of it.</span>
<span class="sd"> Used for ML persistence.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">featuresCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">()</span>
<span class="n">labelCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">()</span>
<span class="n">predictionCol</span> <span class="o">=</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">()</span>
<span class="n">classifier</span><span class="p">:</span> <span class="n">Classifier</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span>
<span class="n">models</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="n">ClassificationModel</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span>
<span class="n">JavaParams</span><span class="o">.</span><span class="n">_from_java</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">models</span><span class="p">()</span>
<span class="p">]</span>
<span class="n">py_stage</span> <span class="o">=</span> <span class="bp">cls</span><span class="p">(</span><span class="n">models</span><span class="o">=</span><span class="n">models</span><span class="p">)</span><span class="o">.</span><span class="n">setPredictionCol</span><span class="p">(</span><span class="n">predictionCol</span><span class="p">)</span><span class="o">.</span><span class="n">setFeaturesCol</span><span class="p">(</span><span class="n">featuresCol</span><span class="p">)</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="n">labelCol</span><span class="p">)</span>
<span class="k">if</span> <span class="n">java_stage</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getParam</span><span class="p">(</span><span class="s2">&quot;weightCol&quot;</span><span class="p">)):</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">weightCol</span><span class="o">=</span><span class="n">java_stage</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">classifier</span><span class="p">)</span>
<span class="n">py_stage</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span><span class="n">java_stage</span><span class="o">.</span><span class="n">uid</span><span class="p">())</span>
<span class="k">return</span> <span class="n">py_stage</span>
<!--
  Highlighted source of the _to_java(self) method.
  The listed code:
    * imports SparkContext locally and asserts that an active context with a
      gateway exists;
    * converts every entry of self.models with _to_java() and packs the results
      into a Java array typed as
      org.apache.spark.ml.classification.ClassificationModel;
    * creates a Java OneVsRestModel from self.uid, metadata.empty() and that
      array;
    * sets classifier, featuresCol, labelCol and predictionCol on the Java
      object, and weightCol only when it is defined and its value is truthy;
    * returns the Java object.
--><span class="k">def</span> <span class="nf">_to_java</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Transfer this instance to a Java OneVsRestModel. Used for ML persistence.</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> py4j.java_gateway.JavaObject</span>
<span class="sd"> Java object equivalent to this instance.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="kn">from</span> <span class="nn">pyspark.core.context</span> <span class="kn">import</span> <span class="n">SparkContext</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">SparkContext</span><span class="o">.</span><span class="n">_active_spark_context</span>
<span class="k">assert</span> <span class="n">sc</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="ow">and</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span>
<span class="n">java_models</span> <span class="o">=</span> <span class="p">[</span><span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassificationModel</span><span class="p">,</span> <span class="n">model</span><span class="p">)</span><span class="o">.</span><span class="n">_to_java</span><span class="p">()</span> <span class="k">for</span> <span class="n">model</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">]</span>
<span class="n">java_models_array</span> <span class="o">=</span> <span class="n">JavaWrapper</span><span class="o">.</span><span class="n">_new_java_array</span><span class="p">(</span>
<span class="n">java_models</span><span class="p">,</span> <span class="n">sc</span><span class="o">.</span><span class="n">_gateway</span><span class="o">.</span><span class="n">jvm</span><span class="o">.</span><span class="n">org</span><span class="o">.</span><span class="n">apache</span><span class="o">.</span><span class="n">spark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">classification</span><span class="o">.</span><span class="n">ClassificationModel</span>
<span class="p">)</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span><span class="s2">&quot;org.apache.spark.sql.types.Metadata&quot;</span><span class="p">)</span>
<span class="n">_java_obj</span> <span class="o">=</span> <span class="n">JavaParams</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.OneVsRestModel&quot;</span><span class="p">,</span>
<span class="bp">self</span><span class="o">.</span><span class="n">uid</span><span class="p">,</span>
<span class="n">metadata</span><span class="o">.</span><span class="n">empty</span><span class="p">(),</span>
<span class="n">java_models_array</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">&quot;classifier&quot;</span><span class="p">,</span> <span class="n">cast</span><span class="p">(</span><span class="n">_JavaClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">())</span><span class="o">.</span><span class="n">_to_java</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">&quot;featuresCol&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getFeaturesCol</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">&quot;labelCol&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getLabelCol</span><span class="p">())</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">&quot;predictionCol&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getPredictionCol</span><span class="p">())</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">isDefined</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weightCol</span><span class="p">)</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">():</span>
<span class="n">_java_obj</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="s2">&quot;weightCol&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">getWeightCol</span><span class="p">())</span>
<span class="k">return</span> <span class="n">_java_obj</span>
<!--
  Viewcode block for OneVsRestModel.read. The "[docs]" link points back to the
  method's API reference page. The listed classmethod returns
  OneVsRestModelReader(cls).
--><div class="viewcode-block" id="OneVsRestModel.read"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.read">[docs]</a> <span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">read</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;OneVsRestModelReader&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">OneVsRestModelReader</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span></div>
<!--
  Viewcode block for OneVsRestModel.write. The "[docs]" link points back to the
  method's API reference page.
  The listed code returns JavaMLWriter(self) when the classifier and every
  entry of self.models are JavaMLWritable instances, and
  OneVsRestModelWriter(self) otherwise.
  The second closing div on the last line ends an element opened earlier in
  the page, outside this block (presumably the enclosing class listing).
--><div class="viewcode-block" id="OneVsRestModel.write"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.OneVsRestModel.html#pyspark.ml.classification.OneVsRestModel.write">[docs]</a> <span class="k">def</span> <span class="nf">write</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">MLWriter</span><span class="p">:</span>
<span class="k">if</span> <span class="nb">all</span><span class="p">(</span>
<span class="nb">map</span><span class="p">(</span>
<span class="k">lambda</span> <span class="n">elem</span><span class="p">:</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">elem</span><span class="p">,</span> <span class="n">JavaMLWritable</span><span class="p">),</span>
<span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">getClassifier</span><span class="p">()]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">models</span><span class="p">,</span> <span class="c1"># type: ignore[operator]</span>
<span class="p">)</span>
<span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="n">OneVsRestModelWriter</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span></div></div>
<!--
  Highlighted source of class OneVsRestModelReader(MLReader[OneVsRestModel]).
  __init__ stores the class it was given in self.cls.
  load(path):
    * loads the metadata with DefaultParamsReader.loadMetadata(path, self.sc);
    * if the metadata is not a Python params instance, delegates to
      JavaMLReader(self.cls).load(path);
    * otherwise loads the classifier through
      _OneVsRestSharedReadWrite.loadClassifier, reads "numClasses" from the
      metadata, and loads one sub-model per class from the "model_{idx}"
      subdirectory of path;
    * builds a OneVsRestModel from those sub-models, resets its uid to the
      metadata "uid", sets the classifier, applies the remaining params with
      getAndSetParams (skipping "classifier"), and returns it.
--><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">OneVsRestModelReader</span><span class="p">(</span><span class="n">MLReader</span><span class="p">[</span><span class="n">OneVsRestModel</span><span class="p">]):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">cls</span><span class="p">:</span> <span class="n">Type</span><span class="p">[</span><span class="n">OneVsRestModel</span><span class="p">]):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestModelReader</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">cls</span> <span class="o">=</span> <span class="bp">cls</span>
<span class="k">def</span> <span class="nf">load</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">OneVsRestModel</span><span class="p">:</span>
<span class="n">metadata</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadMetadata</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="k">if</span> <span class="ow">not</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">isPythonParamsInstance</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span>
<span class="k">return</span> <span class="n">JavaMLReader</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">cls</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="c1"># type: ignore[arg-type]</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">classifier</span> <span class="o">=</span> <span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">loadClassifier</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="n">numClasses</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;numClasses&quot;</span><span class="p">]</span>
<span class="n">subModels</span> <span class="o">=</span> <span class="p">[</span><span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">numClasses</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numClasses</span><span class="p">):</span>
<span class="n">subModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;model_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">subModels</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">loadParamsInstance</span><span class="p">(</span><span class="n">subModelPath</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">)</span>
<span class="n">ovaModel</span> <span class="o">=</span> <span class="n">OneVsRestModel</span><span class="p">(</span><span class="n">cast</span><span class="p">(</span><span class="n">List</span><span class="p">[</span><span class="n">ClassificationModel</span><span class="p">],</span> <span class="n">subModels</span><span class="p">))</span><span class="o">.</span><span class="n">_resetUid</span><span class="p">(</span>
<span class="n">metadata</span><span class="p">[</span><span class="s2">&quot;uid&quot;</span><span class="p">]</span>
<span class="p">)</span>
<span class="n">ovaModel</span><span class="o">.</span><span class="n">set</span><span class="p">(</span><span class="n">ovaModel</span><span class="o">.</span><span class="n">classifier</span><span class="p">,</span> <span class="n">classifier</span><span class="p">)</span>
<span class="n">DefaultParamsReader</span><span class="o">.</span><span class="n">getAndSetParams</span><span class="p">(</span><span class="n">ovaModel</span><span class="p">,</span> <span class="n">metadata</span><span class="p">,</span> <span class="n">skipParams</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;classifier&quot;</span><span class="p">])</span>
<span class="k">return</span> <span class="n">ovaModel</span>
<!--
  Highlighted source of class OneVsRestModelWriter(MLWriter).
  __init__ stores the OneVsRestModel to be written in self.instance.
  saveImpl(path):
    * validates the instance with _OneVsRestSharedReadWrite.validateParams;
    * records numClasses = len(instance.models) as extra metadata and passes it
      to _OneVsRestSharedReadWrite.saveImpl together with the instance, the
      Spark context and the path;
    * saves each sub-model under the "model_{idx}" subdirectory of path, the
      same naming that OneVsRestModelReader.load reads back.
--><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">OneVsRestModelWriter</span><span class="p">(</span><span class="n">MLWriter</span><span class="p">):</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">instance</span><span class="p">:</span> <span class="n">OneVsRestModel</span><span class="p">):</span>
<span class="nb">super</span><span class="p">(</span><span class="n">OneVsRestModelWriter</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">instance</span> <span class="o">=</span> <span class="n">instance</span>
<span class="k">def</span> <span class="nf">saveImpl</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">path</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">validateParams</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">instance</span><span class="p">)</span>
<span class="n">instance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">instance</span>
<span class="n">numClasses</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">instance</span><span class="o">.</span><span class="n">models</span><span class="p">)</span>
<span class="n">extraMetadata</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;numClasses&quot;</span><span class="p">:</span> <span class="n">numClasses</span><span class="p">}</span>
<span class="n">_OneVsRestSharedReadWrite</span><span class="o">.</span><span class="n">saveImpl</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">sc</span><span class="p">,</span> <span class="n">path</span><span class="p">,</span> <span class="n">extraMetadata</span><span class="o">=</span><span class="n">extraMetadata</span><span class="p">)</span>
<span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">numClasses</span><span class="p">):</span>
<span class="n">subModelPath</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="sa">f</span><span class="s2">&quot;model_</span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
<span class="n">cast</span><span class="p">(</span><span class="n">MLWritable</span><span class="p">,</span> <span class="n">instance</span><span class="o">.</span><span class="n">models</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">subModelPath</span><span class="p">)</span>
<div class="viewcode-block" id="FMClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">FMClassifier</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassifier</span><span class="p">[</span><span class="s2">&quot;FMClassificationModel&quot;</span><span class="p">],</span>
<span class="n">_FactorizationMachinesParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;FMClassifier&quot;</span><span class="p">],</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Factorization Machines learning algorithm for classification.</span>
<span class="sd"> Solver supports:</span>
<span class="sd"> * gd (normal mini-batch gradient descent)</span>
<span class="sd"> * adamW (default)</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.linalg import Vectors</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.ml.classification import FMClassifier</span>
<span class="sd"> &gt;&gt;&gt; df = spark.createDataFrame([</span>
<span class="sd"> ... (1.0, Vectors.dense(1.0)),</span>
<span class="sd"> ... (0.0, Vectors.sparse(1, [], []))], [&quot;label&quot;, &quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; fm = FMClassifier(factorSize=2)</span>
<span class="sd"> &gt;&gt;&gt; fm.setSeed(11)</span>
<span class="sd"> FMClassifier...</span>
<span class="sd"> &gt;&gt;&gt; model = fm.fit(df)</span>
<span class="sd"> &gt;&gt;&gt; model.getMaxIter()</span>
<span class="sd"> 100</span>
<span class="sd"> &gt;&gt;&gt; test0 = spark.createDataFrame([</span>
<span class="sd"> ... (Vectors.dense(-1.0),),</span>
<span class="sd"> ... (Vectors.dense(0.5),),</span>
<span class="sd"> ... (Vectors.dense(1.0),),</span>
<span class="sd"> ... (Vectors.dense(2.0),)], [&quot;features&quot;])</span>
<span class="sd"> &gt;&gt;&gt; model.predictRaw(test0.head().features)</span>
<span class="sd"> DenseVector([22.13..., -22.13...])</span>
<span class="sd"> &gt;&gt;&gt; model.predictProbability(test0.head().features)</span>
<span class="sd"> DenseVector([1.0, 0.0])</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).select(&quot;features&quot;, &quot;probability&quot;).show(10, False)</span>
<span class="sd"> +--------+------------------------------------------+</span>
<span class="sd"> |features|probability                               |</span>
<span class="sd"> +--------+------------------------------------------+</span>
<span class="sd"> |[-1.0]  |[0.9999999997574736,2.425264676902229E-10]|</span>
<span class="sd"> |[0.5]   |[0.47627851732981163,0.5237214826701884]  |</span>
<span class="sd"> |[1.0]   |[5.491554426243495E-4,0.9994508445573757] |</span>
<span class="sd"> |[2.0]   |[2.005766663870645E-10,0.9999999997994233]|</span>
<span class="sd"> +--------+------------------------------------------+</span>
<span class="sd"> ...</span>
<span class="sd"> &gt;&gt;&gt; model.intercept</span>
<span class="sd"> -7.316665276826291</span>
<span class="sd"> &gt;&gt;&gt; model.linear</span>
<span class="sd"> DenseVector([14.8232])</span>
<span class="sd"> &gt;&gt;&gt; model.factors</span>
<span class="sd"> DenseMatrix(1, 2, [0.0163, -0.0051], 1)</span>
<span class="sd"> &gt;&gt;&gt; model_path = temp_path + &quot;/fm_model&quot;</span>
<span class="sd"> &gt;&gt;&gt; model.save(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2 = FMClassificationModel.load(model_path)</span>
<span class="sd"> &gt;&gt;&gt; model2.intercept</span>
<span class="sd"> -7.316665276826291</span>
<span class="sd"> &gt;&gt;&gt; model2.linear</span>
<span class="sd"> DenseVector([14.8232])</span>
<span class="sd"> &gt;&gt;&gt; model2.factors</span>
<span class="sd"> DenseMatrix(1, 2, [0.0163, -0.0051], 1)</span>
<span class="sd"> &gt;&gt;&gt; model.transform(test0).take(1) == model2.transform(test0).take(1)</span>
<span class="sd"> True</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">_input_kwargs</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="n">Any</span><span class="p">]</span>
<span class="nd">@keyword_only</span>
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">factorSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">fitLinear</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">miniBatchFraction</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">initStd</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;adamW&quot;</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> __init__(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \</span>
<span class="sd"> miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \</span>
<span class="sd"> tol=1e-6, solver=&quot;adamW&quot;, thresholds=None, seed=None)</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nb">super</span><span class="p">(</span><span class="n">FMClassifier</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
<span class="bp">self</span><span class="o">.</span><span class="n">_java_obj</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_new_java_obj</span><span class="p">(</span>
<span class="s2">&quot;org.apache.spark.ml.classification.FMClassifier&quot;</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">uid</span>
<span class="p">)</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="bp">self</span><span class="o">.</span><span class="n">setParams</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
<div class="viewcode-block" id="FMClassifier.setParams"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setParams">[docs]</a> <span class="nd">@keyword_only</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setParams</span><span class="p">(</span>
<span class="bp">self</span><span class="p">,</span>
<span class="o">*</span><span class="p">,</span>
<span class="n">featuresCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;features&quot;</span><span class="p">,</span>
<span class="n">labelCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;label&quot;</span><span class="p">,</span>
<span class="n">predictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;prediction&quot;</span><span class="p">,</span>
<span class="n">probabilityCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;probability&quot;</span><span class="p">,</span>
<span class="n">rawPredictionCol</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;rawPrediction&quot;</span><span class="p">,</span>
<span class="n">factorSize</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span>
<span class="n">fitIntercept</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">fitLinear</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">,</span>
<span class="n">regParam</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="n">miniBatchFraction</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">initStd</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">,</span>
<span class="n">maxIter</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">stepSize</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1.0</span><span class="p">,</span>
<span class="n">tol</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">1e-6</span><span class="p">,</span>
<span class="n">solver</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;adamW&quot;</span><span class="p">,</span>
<span class="n">thresholds</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="n">List</span><span class="p">[</span><span class="nb">float</span><span class="p">]]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> setParams(self, \\*, featuresCol=&quot;features&quot;, labelCol=&quot;label&quot;, predictionCol=&quot;prediction&quot;, \</span>
<span class="sd"> probabilityCol=&quot;probability&quot;, rawPredictionCol=&quot;rawPrediction&quot;, \</span>
<span class="sd"> factorSize=8, fitIntercept=True, fitLinear=True, regParam=0.0, \</span>
<span class="sd"> miniBatchFraction=1.0, initStd=0.01, maxIter=100, stepSize=1.0, \</span>
<span class="sd"> tol=1e-6, solver=&quot;adamW&quot;, thresholds=None, seed=None)</span>
<span class="sd"> Sets Params for FMClassifier.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">kwargs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_kwargs</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span></div>
<span class="k">def</span> <span class="nf">_create_model</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">java_model</span><span class="p">:</span> <span class="s2">&quot;JavaObject&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassificationModel&quot;</span><span class="p">:</span>
<span class="k">return</span> <span class="n">FMClassificationModel</span><span class="p">(</span><span class="n">java_model</span><span class="p">)</span>
<div class="viewcode-block" id="FMClassifier.setFactorSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setFactorSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFactorSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`factorSize`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">factorSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setFitLinear"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setFitLinear">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFitLinear</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`fitLinear`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitLinear</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setMiniBatchFraction"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setMiniBatchFraction">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMiniBatchFraction</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`miniBatchFraction`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">miniBatchFraction</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setInitStd"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setInitStd">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setInitStd</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`initStd`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">initStd</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setMaxIter"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setMaxIter">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setMaxIter</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`maxIter`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setStepSize"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setStepSize">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setStepSize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`stepSize`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">stepSize</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setTol"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setTol">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setTol</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`tol`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">tol</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setSolver"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setSolver">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setSolver</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">str</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`solver`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">solver</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setSeed"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setSeed">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setSeed</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">int</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`seed`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">seed</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setFitIntercept"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setFitIntercept">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setFitIntercept</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">bool</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`fitIntercept`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">fitIntercept</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div>
<div class="viewcode-block" id="FMClassifier.setRegParam"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassifier.html#pyspark.ml.classification.FMClassifier.setRegParam">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">setRegParam</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">value</span><span class="p">:</span> <span class="nb">float</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassifier&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Sets the value of :py:attr:`regParam`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_set</span><span class="p">(</span><span class="n">regParam</span><span class="o">=</span><span class="n">value</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="FMClassificationModel"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationModel.html#pyspark.ml.classification.FMClassificationModel">[docs]</a><span class="k">class</span> <span class="nc">FMClassificationModel</span><span class="p">(</span>
<span class="n">_JavaProbabilisticClassificationModel</span><span class="p">[</span><span class="n">Vector</span><span class="p">],</span>
<span class="n">_FactorizationMachinesParams</span><span class="p">,</span>
<span class="n">JavaMLWritable</span><span class="p">,</span>
<span class="n">JavaMLReadable</span><span class="p">[</span><span class="s2">&quot;FMClassificationModel&quot;</span><span class="p">],</span>
<span class="n">HasTrainingSummary</span><span class="p">,</span>
<span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model fitted by :class:`FMClassifier`.</span>
<span class="sd"> .. versionadded:: 3.0.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">intercept</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model intercept.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;intercept&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">linear</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Vector</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model linear term.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;linear&quot;</span><span class="p">)</span>
<span class="nd">@property</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.0.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">factors</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">Matrix</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Model factor term.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;factors&quot;</span><span class="p">)</span>
<div class="viewcode-block" id="FMClassificationModel.summary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationModel.html#pyspark.ml.classification.FMClassificationModel.summary">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;3.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassificationTrainingSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Gets summary (accuracy/precision/recall, objective history, total iterations) of model</span>
<span class="sd"> trained on the training set. An exception is thrown if `trainingSummary is None`.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">hasSummary</span><span class="p">:</span>
<span class="k">return</span> <span class="n">FMClassificationTrainingSummary</span><span class="p">(</span><span class="nb">super</span><span class="p">(</span><span class="n">FMClassificationModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span>
<span class="s2">&quot;No training summary available for this </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="FMClassificationModel.evaluate"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationModel.html#pyspark.ml.classification.FMClassificationModel.evaluate">[docs]</a> <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">DataFrame</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="s2">&quot;FMClassificationSummary&quot;</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Evaluates the model on a test dataset.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> dataset : :py:class:`pyspark.sql.DataFrame`</span>
<span class="sd"> Test dataset to evaluate model on.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">DataFrame</span><span class="p">):</span>
<span class="k">raise</span> <span class="ne">TypeError</span><span class="p">(</span><span class="s2">&quot;dataset must be a DataFrame but got </span><span class="si">%s</span><span class="s2">.&quot;</span> <span class="o">%</span> <span class="nb">type</span><span class="p">(</span><span class="n">dataset</span><span class="p">))</span>
<span class="n">java_fm_summary</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_call_java</span><span class="p">(</span><span class="s2">&quot;evaluate&quot;</span><span class="p">,</span> <span class="n">dataset</span><span class="p">)</span>
<span class="k">return</span> <span class="n">FMClassificationSummary</span><span class="p">(</span><span class="n">java_fm_summary</span><span class="p">)</span></div></div>
<div class="viewcode-block" id="FMClassificationSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationSummary.html#pyspark.ml.classification.FMClassificationSummary">[docs]</a><span class="k">class</span> <span class="nc">FMClassificationSummary</span><span class="p">(</span><span class="n">_BinaryClassificationSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for FMClassifier Results for a given model.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<div class="viewcode-block" id="FMClassificationTrainingSummary"><a class="viewcode-back" href="../../../reference/api/pyspark.ml.classification.FMClassificationTrainingSummary.html#pyspark.ml.classification.FMClassificationTrainingSummary">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">FMClassificationTrainingSummary</span><span class="p">(</span><span class="n">FMClassificationSummary</span><span class="p">,</span> <span class="n">_TrainingSummary</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Abstraction for FMClassifier Training results.</span>
<span class="sd"> .. versionadded:: 3.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">pass</span></div>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">doctest</span>
<span class="kn">import</span> <span class="nn">pyspark.ml.classification</span>
<span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">SparkSession</span>
<span class="n">globs</span> <span class="o">=</span> <span class="n">pyspark</span><span class="o">.</span><span class="n">ml</span><span class="o">.</span><span class="n">classification</span><span class="o">.</span><span class="vm">__dict__</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="c1"># The small batch size here ensures that we see multiple batches,</span>
<span class="c1"># even in these small test examples:</span>
<span class="n">spark</span> <span class="o">=</span> <span class="n">SparkSession</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">master</span><span class="p">(</span><span class="s2">&quot;local[2]&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">appName</span><span class="p">(</span><span class="s2">&quot;ml.classification tests&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">getOrCreate</span><span class="p">()</span>
<span class="n">sc</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">sparkContext</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;sc&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">sc</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;spark&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">spark</span>
<span class="kn">import</span> <span class="nn">tempfile</span>
<span class="n">temp_path</span> <span class="o">=</span> <span class="n">tempfile</span><span class="o">.</span><span class="n">mkdtemp</span><span class="p">()</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;temp_path&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">temp_path</span>
<span class="k">try</span><span class="p">:</span>
<span class="p">(</span><span class="n">failure_count</span><span class="p">,</span> <span class="n">test_count</span><span class="p">)</span> <span class="o">=</span> <span class="n">doctest</span><span class="o">.</span><span class="n">testmod</span><span class="p">(</span><span class="n">globs</span><span class="o">=</span><span class="n">globs</span><span class="p">,</span> <span class="n">optionflags</span><span class="o">=</span><span class="n">doctest</span><span class="o">.</span><span class="n">ELLIPSIS</span><span class="p">)</span>
<span class="n">spark</span><span class="o">.</span><span class="n">stop</span><span class="p">()</span>
<span class="k">finally</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">shutil</span> <span class="kn">import</span> <span class="n">rmtree</span>
<span class="k">try</span><span class="p">:</span>
<span class="n">rmtree</span><span class="p">(</span><span class="n">temp_path</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">OSError</span><span class="p">:</span>
<span class="k">pass</span>
<span class="k">if</span> <span class="n">failure_count</span><span class="p">:</span>
<span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
</pre></div>
</article>
<footer class="bd-footer-article">
  <div class="footer-article-items footer-article__inner">
    <div class="footer-article-item">
      <!-- Previous / next navigation (no entries on this page) -->
      <div class="prev-next-area">
      </div>
    </div>
  </div>
</footer>
</div>
</div>
<footer class="bd-footer-content">
</footer>
</main>
</div>
</div>
<!-- Theme scripts are included here, after the main page content, so parsing of that content is not blocked; both files are preloaded in <head>. -->
<script src="../../../_static/scripts/bootstrap.js?digest=e353d410970836974a52"></script>
<script src="../../../_static/scripts/pydata-sphinx-theme.js?digest=e353d410970836974a52"></script>
<!-- Site-wide footer: copyright and licence notice, plus the Sphinx and theme versions used for the build -->
<footer class="bd-footer">
  <div class="bd-footer__inner bd-page-width">
    <div class="footer-items__start">
      <div class="footer-item">
        <p class="copyright">
          Copyright © 2024 The Apache Software Foundation, Licensed under the <a href="https://www.apache.org/licenses/LICENSE-2.0">Apache License, Version 2.0</a>.
        </p>
      </div>
      <div class="footer-item">
        <p class="sphinx-version">
          Created using <a href="https://www.sphinx-doc.org/">Sphinx</a> 4.5.0.
        </p>
      </div>
    </div>
    <div class="footer-items__end">
      <div class="footer-item">
        <p class="theme-version">
          Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.3.
        </p>
      </div>
    </div>
  </div>
</footer>
</body>
</html>