<!-- blob: 48b88bfcd98d06d83fee0d115a7922275ee5cafe [file] [log] [blame] -->
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>pyspark.mllib.tree &#8212; PySpark 4.0.0-preview2 documentation</title>
<script data-cfasync="false">
// Copy the saved display mode and theme from localStorage onto the root element's
// data-mode / data-theme attributes. This script sits ahead of the stylesheet links,
// so the attributes are already set when the theme CSS is applied.
// A missing or empty stored value falls back to "" for mode and "light" for theme.
document.documentElement.dataset.mode = localStorage.getItem("mode") || "";
document.documentElement.dataset.theme = localStorage.getItem("theme") || "light";
</script>
<!-- Loaded before other Sphinx assets -->
<link href="../../../_static/styles/theme.css?digest=e353d410970836974a52" rel="stylesheet" />
<link href="../../../_static/styles/bootstrap.css?digest=e353d410970836974a52" rel="stylesheet" />
<link href="../../../_static/styles/pydata-sphinx-theme.css?digest=e353d410970836974a52" rel="stylesheet" />
<link href="../../../_static/vendor/fontawesome/6.1.2/css/all.min.css?digest=e353d410970836974a52" rel="stylesheet" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/6.1.2/webfonts/fa-solid-900.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/6.1.2/webfonts/fa-brands-400.woff2" />
<link rel="preload" as="font" type="font/woff2" crossorigin href="../../../_static/vendor/fontawesome/6.1.2/webfonts/fa-regular-400.woff2" />
<link rel="stylesheet" type="text/css" href="../../../_static/pygments.css" />
<link rel="stylesheet" type="text/css" href="../../../_static/copybutton.css" />
<link rel="stylesheet" type="text/css" href="../../../_static/css/pyspark.css" />
<!-- Pre-loaded scripts that we'll load fully later -->
<link rel="preload" as="script" href="../../../_static/scripts/bootstrap.js?digest=e353d410970836974a52" />
<link rel="preload" as="script" href="../../../_static/scripts/pydata-sphinx-theme.js?digest=e353d410970836974a52" />
<script data-url_root="../../../" id="documentation_options" src="../../../_static/documentation_options.js"></script>
<script src="../../../_static/jquery.js"></script>
<script src="../../../_static/underscore.js"></script>
<script src="../../../_static/doctools.js"></script>
<script src="../../../_static/clipboard.min.js"></script>
<script src="../../../_static/copybutton.js"></script>
<script crossorigin="anonymous" integrity="sha256-Ae2Vz/4ePdIu6ZyI/5ZGsYnb+m0JlOmKPjt6XZ9JJkA=" src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.4/require.min.js"></script>
<script>DOCUMENTATION_OPTIONS.pagename = '_modules/pyspark/mllib/tree';</script>
<link rel="canonical" href="https://spark.apache.org/docs/latest/api/python/_modules/pyspark/mllib/tree.html" />
<link rel="search" title="Search" href="../../../search.html" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<meta name="docsearch:language" content="en">
<!-- Matomo -->
<script type="text/javascript">
// Matomo command queue: reuse window._paq if it already exists, otherwise start an empty array.
// Commands are pushed as arrays of [methodName, ...args].
var _paq = window._paq = window._paq || [];
/* tracker methods like "setCustomDimension" should be called before "trackPageView" */
_paq.push(["disableCookies"]);
_paq.push(['trackPageView']);
_paq.push(['enableLinkTracking']);
(function() {
// Base URL of the analytics server; both the tracking endpoint and the tracker script hang off it.
var u="https://analytics.apache.org/";
_paq.push(['setTrackerUrl', u+'matomo.php']);
_paq.push(['setSiteId', '40']);
// Build a script element for matomo.js, mark it async, and insert it ahead of the
// first script element in the document.
var d=document, g=d.createElement('script'), s=d.getElementsByTagName('script')[0];
g.async=true; g.src=u+'matomo.js'; s.parentNode.insertBefore(g,s);
})();
</script>
<!-- End Matomo Code -->
</head>
<body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode="">
<a class="skip-link" href="#main-content">Skip to main content</a>
<input type="checkbox"
class="sidebar-toggle"
name="__primary"
id="__primary"/>
<label class="overlay overlay-primary" for="__primary"></label>
<input type="checkbox"
class="sidebar-toggle"
name="__secondary"
id="__secondary"/>
<label class="overlay overlay-secondary" for="__secondary"></label>
<div class="search-button__wrapper">
<div class="search-button__overlay"></div>
<div class="search-button__search-container">
<form class="bd-search d-flex align-items-center"
action="../../../search.html"
method="get">
<i class="fa-solid fa-magnifying-glass"></i>
<input type="search"
class="form-control"
name="q"
id="search-input"
placeholder="Search the docs ..."
aria-label="Search the docs ..."
autocomplete="off"
autocorrect="off"
autocapitalize="off"
spellcheck="false"/>
<span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span>
</form></div>
</div>
<nav class="bd-header navbar navbar-expand-lg bd-navbar">
<div class="bd-header__inner bd-page-width">
<label class="sidebar-toggle primary-toggle" for="__primary">
<span class="fa-solid fa-bars"></span>
</label>
<div class="navbar-header-items__start">
<div class="navbar-item">
<a class="navbar-brand logo" href="../../../index.html">
<img src="https://spark.apache.org/images/spark-logo.png" class="logo__image only-light" alt="Logo image"/>
<script>document.write(`<img src="https://spark.apache.org/images/spark-logo-rev.svg" class="logo__image only-dark" alt="Logo image"/>`);</script>
</a></div>
</div>
<div class="col-lg-9 navbar-header-items">
<div class="me-auto navbar-header-items__center">
<div class="navbar-item"><nav class="navbar-nav">
<p class="sidebar-header-items__title"
role="heading"
aria-level="1"
aria-label="Site Navigation">
Site Navigation
</p>
<ul class="bd-navbar-elements navbar-nav">
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../index.html">
Overview
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../getting_started/index.html">
Getting Started
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../user_guide/index.html">
User Guides
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../reference/index.html">
API Reference
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../development/index.html">
Development
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../migration_guide/index.html">
Migration Guides
</a>
</li>
</ul>
</nav></div>
</div>
<div class="navbar-header-items__end">
<div class="navbar-item navbar-persistent--container">
<script>
// Write the header search button into the page from script, so the button only
// exists in the document when this script actually executes.
document.write(`
<button class="btn btn-sm navbar-btn search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
</button>
`);
</script>
</div>
<div class="navbar-item"><!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<div id="version-button" class="dropdown">
<button type="button" class="btn btn-secondary btn-sm navbar-btn dropdown-toggle" id="version_switcher_button" data-toggle="dropdown">
4.0.0-preview2
<span class="caret"></span>
</button>
<div id="version_switcher" class="dropdown-menu list-group-flush py-0" aria-labelledby="version_switcher_button">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div>
<script type="text/javascript">
/**
 * Build the documentation homepage URL for one version entry.
 *
 * @param {{version: string}} entry - Version record; only `version` is read.
 * @returns {string} The URL template with its first "{version}" placeholder
 *     replaced by `entry.version`.
 */
function buildURL(entry) {
  // The template text is supplied by jinja when the page is rendered.
  const urlTemplate = "https://spark.apache.org/docs/{version}/api/python/index.html";
  return urlTemplate.replace("{version}", entry.version);
}
/**
 * Click handler for a version-switcher link.
 *
 * Probes (HTTP HEAD) whether the page currently being viewed also exists in
 * the selected docs version. If it does, navigates there; if the probe fails,
 * navigates to that version's homepage instead.
 *
 * @param {Event} event - Click event raised on a version link.
 * @returns {boolean} Always false, so the browser does not follow the link
 *     itself; navigation happens in the AJAX callbacks.
 */
function checkPageExistsAndRedirect(event) {
  const currentFilePath = "_modules/pyspark/mllib/tree.html";
  // currentTarget is the link the handler is bound to; target is kept as a
  // fallback for callers that pass a minimal event object.
  const link = event.currentTarget || event.target;
  const otherDocsHomepage = link.getAttribute("href");
  // Resolve the page path relative to the homepage URL rather than
  // concatenating the two strings: the homepage ends in "index.html", so
  // concatenation produced ".../index.html_modules/..." and the probe below
  // could never succeed. link.href is the absolute form of the attribute.
  const tryUrl = new URL(currentFilePath, link.href).href;
  $.ajax({
    type: 'HEAD',
    url: tryUrl,
    // if the page exists, go there
    success: function() {
      location.href = tryUrl;
    }
  }).fail(function() {
    // otherwise fall back to the homepage of the selected version
    location.href = otherDocsHomepage;
  });
  return false;
}
// Populate the version switcher that sits next to this script
(function () {
  // The switcher markup is rendered more than once on this page with the same
  // ids, so looking the menu up by id always hits the first copy: that copy
  // would receive every entry once per script and the other copies none.
  // Resolve the menu relative to this script element instead, and only fall
  // back to the id lookup when currentScript is unavailable.
  const ownScript = document.currentScript;
  const menu =
    (ownScript && ownScript.parentNode.querySelector(".dropdown-menu")) ||
    document.getElementById("version_switcher");
  if (!menu) {
    // No dropdown to fill; nothing to do.
    return;
  }
  // get JSON config
  $.getJSON("https://spark.apache.org/static/versions.json", function (data) {
    // create the nodes first (before AJAX calls) to ensure the order is
    // correct (for now, links will go to doc version homepage)
    $.each(data, function (index, entry) {
      // if no custom name specified (e.g., "latest"), use version string
      if (!("name" in entry)) {
        entry.name = entry.version;
      }
      // construct the appropriate URL, and add it to the dropdown
      entry.url = buildURL(entry);
      const node = document.createElement("a");
      node.setAttribute("class", "list-group-item list-group-item-action py-1");
      node.setAttribute("href", `${entry.url}`);
      node.textContent = `${entry.name}`;
      // On click, try the same page in the other version before its homepage.
      node.onclick = checkPageExistsAndRedirect;
      menu.appendChild(node);
    });
  });
})();
</script></div>
<div class="navbar-item">
<script>
// Write the light/dark/auto theme toggle button into the page from script, so the
// button only exists in the document when this script actually executes.
// It holds one icon span per data-mode value.
document.write(`
<button class="theme-switch-button btn btn-sm btn-outline-primary navbar-btn rounded-circle" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="theme-switch" data-mode="light"><i class="fa-solid fa-sun"></i></span>
<span class="theme-switch" data-mode="dark"><i class="fa-solid fa-moon"></i></span>
<span class="theme-switch" data-mode="auto"><i class="fa-solid fa-circle-half-stroke"></i></span>
</button>
`);
</script></div>
<div class="navbar-item"><ul class="navbar-icon-links navbar-nav"
aria-label="Icon Links">
<li class="nav-item">
<a href="https://github.com/apache/spark" title="GitHub" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-brands fa-github"></i></span>
<label class="sr-only">GitHub</label></a>
</li>
<li class="nav-item">
<a href="https://pypi.org/project/pyspark" title="PyPI" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-solid fa-box"></i></span>
<label class="sr-only">PyPI</label></a>
</li>
</ul></div>
</div>
</div>
<div class="navbar-persistent--mobile">
<script>
// Write the search button for the mobile navbar container into the page from script,
// so the button only exists in the document when this script actually executes.
document.write(`
<button class="btn btn-sm navbar-btn search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip">
<i class="fa-solid fa-magnifying-glass"></i>
</button>
`);
</script>
</div>
</div>
</nav>
<div class="bd-container">
<div class="bd-container__inner bd-page-width">
<div class="bd-sidebar-primary bd-sidebar hide-on-wide">
<div class="sidebar-header-items sidebar-primary__section">
<div class="sidebar-header-items__center">
<div class="navbar-item"><nav class="navbar-nav">
<p class="sidebar-header-items__title"
role="heading"
aria-level="1"
aria-label="Site Navigation">
Site Navigation
</p>
<ul class="bd-navbar-elements navbar-nav">
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../index.html">
Overview
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../getting_started/index.html">
Getting Started
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../user_guide/index.html">
User Guides
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../reference/index.html">
API Reference
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../development/index.html">
Development
</a>
</li>
<li class="nav-item">
<a class="nav-link nav-internal" href="../../../migration_guide/index.html">
Migration Guides
</a>
</li>
</ul>
</nav></div>
</div>
<div class="sidebar-header-items__end">
<div class="navbar-item"><!--
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
<div id="version-button" class="dropdown">
<button type="button" class="btn btn-secondary btn-sm navbar-btn dropdown-toggle" id="version_switcher_button" data-toggle="dropdown">
4.0.0-preview2
<span class="caret"></span>
</button>
<div id="version_switcher" class="dropdown-menu list-group-flush py-0" aria-labelledby="version_switcher_button">
<!-- dropdown will be populated by javascript on page load -->
</div>
</div>
<script type="text/javascript">
/**
 * Build the documentation homepage URL for one version entry.
 *
 * @param {{version: string}} entry - Version record; only `version` is read.
 * @returns {string} The URL template with its first "{version}" placeholder
 *     replaced by `entry.version`.
 */
function buildURL(entry) {
  // The template text is supplied by jinja when the page is rendered.
  const urlTemplate = "https://spark.apache.org/docs/{version}/api/python/index.html";
  return urlTemplate.replace("{version}", entry.version);
}
/**
 * Click handler for a version-switcher link.
 *
 * Probes (HTTP HEAD) whether the page currently being viewed also exists in
 * the selected docs version. If it does, navigates there; if the probe fails,
 * navigates to that version's homepage instead.
 *
 * @param {Event} event - Click event raised on a version link.
 * @returns {boolean} Always false, so the browser does not follow the link
 *     itself; navigation happens in the AJAX callbacks.
 */
function checkPageExistsAndRedirect(event) {
  const currentFilePath = "_modules/pyspark/mllib/tree.html";
  // currentTarget is the link the handler is bound to; target is kept as a
  // fallback for callers that pass a minimal event object.
  const link = event.currentTarget || event.target;
  const otherDocsHomepage = link.getAttribute("href");
  // Resolve the page path relative to the homepage URL rather than
  // concatenating the two strings: the homepage ends in "index.html", so
  // concatenation produced ".../index.html_modules/..." and the probe below
  // could never succeed. link.href is the absolute form of the attribute.
  const tryUrl = new URL(currentFilePath, link.href).href;
  $.ajax({
    type: 'HEAD',
    url: tryUrl,
    // if the page exists, go there
    success: function() {
      location.href = tryUrl;
    }
  }).fail(function() {
    // otherwise fall back to the homepage of the selected version
    location.href = otherDocsHomepage;
  });
  return false;
}
// Populate the version switcher that sits next to this script
(function () {
  // The switcher markup is rendered more than once on this page with the same
  // ids, so looking the menu up by id always hits the first copy: that copy
  // would receive every entry once per script and the other copies none.
  // Resolve the menu relative to this script element instead, and only fall
  // back to the id lookup when currentScript is unavailable.
  const ownScript = document.currentScript;
  const menu =
    (ownScript && ownScript.parentNode.querySelector(".dropdown-menu")) ||
    document.getElementById("version_switcher");
  if (!menu) {
    // No dropdown to fill; nothing to do.
    return;
  }
  // get JSON config
  $.getJSON("https://spark.apache.org/static/versions.json", function (data) {
    // create the nodes first (before AJAX calls) to ensure the order is
    // correct (for now, links will go to doc version homepage)
    $.each(data, function (index, entry) {
      // if no custom name specified (e.g., "latest"), use version string
      if (!("name" in entry)) {
        entry.name = entry.version;
      }
      // construct the appropriate URL, and add it to the dropdown
      entry.url = buildURL(entry);
      const node = document.createElement("a");
      node.setAttribute("class", "list-group-item list-group-item-action py-1");
      node.setAttribute("href", `${entry.url}`);
      node.textContent = `${entry.name}`;
      // On click, try the same page in the other version before its homepage.
      node.onclick = checkPageExistsAndRedirect;
      menu.appendChild(node);
    });
  });
})();
</script></div>
<div class="navbar-item">
<script>
// Write the sidebar copy of the light/dark/auto theme toggle button into the page from
// script, so the button only exists in the document when this script actually executes.
// It holds one icon span per data-mode value.
document.write(`
<button class="theme-switch-button btn btn-sm btn-outline-primary navbar-btn rounded-circle" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip">
<span class="theme-switch" data-mode="light"><i class="fa-solid fa-sun"></i></span>
<span class="theme-switch" data-mode="dark"><i class="fa-solid fa-moon"></i></span>
<span class="theme-switch" data-mode="auto"><i class="fa-solid fa-circle-half-stroke"></i></span>
</button>
`);
</script></div>
<div class="navbar-item"><ul class="navbar-icon-links navbar-nav"
aria-label="Icon Links">
<li class="nav-item">
<a href="https://github.com/apache/spark" title="GitHub" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-brands fa-github"></i></span>
<label class="sr-only">GitHub</label></a>
</li>
<li class="nav-item">
<a href="https://pypi.org/project/pyspark" title="PyPI" class="nav-link" rel="noopener" target="_blank" data-bs-toggle="tooltip" data-bs-placement="bottom"><span><i class="fa-solid fa-box"></i></span>
<label class="sr-only">PyPI</label></a>
</li>
</ul></div>
</div>
</div>
<div class="sidebar-primary-items__end sidebar-primary__section">
</div>
<div id="rtd-footer-container"></div>
</div>
<main id="main-content" class="bd-main">
<div class="bd-content">
<div class="bd-article-container">
<div class="bd-header-article">
<div class="header-article-items header-article__inner">
<div class="header-article-items__start">
<div class="header-article-item">
<nav aria-label="Breadcrumbs">
<ul class="bd-breadcrumbs">
<li class="breadcrumb-item breadcrumb-home">
<a href="../../../index.html" class="nav-link" aria-label="Home">
<i class="fa-solid fa-home"></i>
</a>
</li>
<li class="breadcrumb-item"><a href="../../index.html" class="nav-link">Module code</a></li>
<li class="breadcrumb-item active" aria-current="page">pyspark.mllib.tree</li>
</ul>
</nav>
</div>
</div>
</div>
</div>
<div id="searchbox"></div>
<article class="bd-article" role="main">
<h1>Source code for pyspark.mllib.tree</h1><div class="highlight"><pre>
<span></span><span class="c1">#</span>
<span class="c1"># Licensed to the Apache Software Foundation (ASF) under one or more</span>
<span class="c1"># contributor license agreements. See the NOTICE file distributed with</span>
<span class="c1"># this work for additional information regarding copyright ownership.</span>
<span class="c1"># The ASF licenses this file to You under the Apache License, Version 2.0</span>
<span class="c1"># (the &quot;License&quot;); you may not use this file except in compliance with</span>
<span class="c1"># the License. You may obtain a copy of the License at</span>
<span class="c1">#</span>
<span class="c1"># http://www.apache.org/licenses/LICENSE-2.0</span>
<span class="c1">#</span>
<span class="c1"># Unless required by applicable law or agreed to in writing, software</span>
<span class="c1"># distributed under the License is distributed on an &quot;AS IS&quot; BASIS,</span>
<span class="c1"># WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.</span>
<span class="c1"># See the License for the specific language governing permissions and</span>
<span class="c1"># limitations under the License.</span>
<span class="c1">#</span>
<span class="kn">import</span> <span class="nn">sys</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">from</span> <span class="nn">pyspark</span> <span class="kn">import</span> <span class="n">RDD</span><span class="p">,</span> <span class="n">since</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.common</span> <span class="kn">import</span> <span class="n">callMLlibFunc</span><span class="p">,</span> <span class="n">inherit_doc</span><span class="p">,</span> <span class="n">JavaModelWrapper</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.linalg</span> <span class="kn">import</span> <span class="n">_convert_to_vector</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.regression</span> <span class="kn">import</span> <span class="n">LabeledPoint</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.util</span> <span class="kn">import</span> <span class="n">JavaLoader</span><span class="p">,</span> <span class="n">JavaSaveable</span>
<span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Dict</span><span class="p">,</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Tuple</span><span class="p">,</span> <span class="n">Union</span><span class="p">,</span> <span class="n">overload</span><span class="p">,</span> <span class="n">TYPE_CHECKING</span>
<span class="kn">from</span> <span class="nn">pyspark.core.rdd</span> <span class="kn">import</span> <span class="n">RDD</span>
<span class="k">if</span> <span class="n">TYPE_CHECKING</span><span class="p">:</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib._typing</span> <span class="kn">import</span> <span class="n">VectorLike</span>
<span class="n">__all__</span> <span class="o">=</span> <span class="p">[</span>
<span class="s2">&quot;DecisionTreeModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;DecisionTree&quot;</span><span class="p">,</span>
<span class="s2">&quot;RandomForestModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;RandomForest&quot;</span><span class="p">,</span>
<span class="s2">&quot;GradientBoostedTreesModel&quot;</span><span class="p">,</span>
<span class="s2">&quot;GradientBoostedTrees&quot;</span><span class="p">,</span>
<span class="p">]</span>
<span class="k">class</span> <span class="nc">TreeEnsembleModel</span><span class="p">(</span><span class="n">JavaModelWrapper</span><span class="p">,</span> <span class="n">JavaSaveable</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;TreeEnsembleModel</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="s2">&quot;VectorLike&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="o">...</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="s2">&quot;VectorLike&quot;</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">RDD</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="o">...</span>
<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">&quot;VectorLike&quot;</span><span class="p">,</span> <span class="n">RDD</span><span class="p">[</span><span class="s2">&quot;VectorLike&quot;</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">RDD</span><span class="p">[</span><span class="nb">float</span><span class="p">]]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Predict values for a single data point or an RDD of points using</span>
<span class="sd"> the model trained.</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> In Python, predict cannot currently be used within an RDD</span>
<span class="sd"> transformation or action.</span>
<span class="sd"> Call predict directly on the RDD instead.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">RDD</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="s2">&quot;predict&quot;</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">_convert_to_vector</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="s2">&quot;predict&quot;</span><span class="p">,</span> <span class="n">_convert_to_vector</span><span class="p">(</span><span class="n">x</span><span class="p">))</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">numTrees</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get number of trees in ensemble.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="s2">&quot;numTrees&quot;</span><span class="p">)</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">totalNumNodes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get total number of nodes, summed over all trees in the ensemble.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="s2">&quot;totalNumNodes&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Summary of model&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_java_model</span><span class="o">.</span><span class="n">toString</span><span class="p">()</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.3.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">toDebugString</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Full model&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_java_model</span><span class="o">.</span><span class="n">toDebugString</span><span class="p">()</span>
<div class="viewcode-block" id="DecisionTreeModel"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTreeModel.html#pyspark.mllib.tree.DecisionTreeModel">[docs]</a><span class="k">class</span> <span class="nc">DecisionTreeModel</span><span class="p">(</span><span class="n">JavaModelWrapper</span><span class="p">,</span> <span class="n">JavaSaveable</span><span class="p">,</span> <span class="n">JavaLoader</span><span class="p">[</span><span class="s2">&quot;DecisionTreeModel&quot;</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> A decision tree model for classification or regression.</span>
<span class="sd"> .. versionadded:: 1.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="s2">&quot;VectorLike&quot;</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">float</span><span class="p">:</span>
<span class="o">...</span>
<span class="nd">@overload</span>
<span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="s2">&quot;VectorLike&quot;</span><span class="p">])</span> <span class="o">-&gt;</span> <span class="n">RDD</span><span class="p">[</span><span class="nb">float</span><span class="p">]:</span>
<span class="o">...</span>
<div class="viewcode-block" id="DecisionTreeModel.predict"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTreeModel.html#pyspark.mllib.tree.DecisionTreeModel.predict">[docs]</a> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">Union</span><span class="p">[</span><span class="s2">&quot;VectorLike&quot;</span><span class="p">,</span> <span class="n">RDD</span><span class="p">[</span><span class="s2">&quot;VectorLike&quot;</span><span class="p">]])</span> <span class="o">-&gt;</span> <span class="n">Union</span><span class="p">[</span><span class="nb">float</span><span class="p">,</span> <span class="n">RDD</span><span class="p">[</span><span class="nb">float</span><span class="p">]]:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Predict the label of one or more examples.</span>
<span class="sd"> .. versionadded:: 1.1.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> x : :py:class:`pyspark.mllib.linalg.Vector` or :py:class:`pyspark.RDD`</span>
<span class="sd"> Data point (feature vector), or an RDD of data points (feature</span>
<span class="sd"> vectors).</span>
<span class="sd"> Notes</span>
<span class="sd"> -----</span>
<span class="sd"> In Python, predict cannot currently be used within an RDD</span>
<span class="sd"> transformation or action.</span>
<span class="sd"> Call predict directly on the RDD instead.</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">RDD</span><span class="p">):</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="s2">&quot;predict&quot;</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">_convert_to_vector</span><span class="p">))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="s2">&quot;predict&quot;</span><span class="p">,</span> <span class="n">_convert_to_vector</span><span class="p">(</span><span class="n">x</span><span class="p">))</span></div>
<div class="viewcode-block" id="DecisionTreeModel.numNodes"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTreeModel.html#pyspark.mllib.tree.DecisionTreeModel.numNodes">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">numNodes</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Get number of nodes in tree, including leaf nodes.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_java_model</span><span class="o">.</span><span class="n">numNodes</span><span class="p">()</span></div>
<div class="viewcode-block" id="DecisionTreeModel.depth"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTreeModel.html#pyspark.mllib.tree.DecisionTreeModel.depth">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">depth</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">int</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Get depth of tree (e.g. depth 0 means 1 leaf node, depth 1</span>
<span class="sd"> means 1 internal node + 2 leaf nodes).</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_java_model</span><span class="o">.</span><span class="n">depth</span><span class="p">()</span></div>
<span class="k">def</span> <span class="fm">__repr__</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Summary of the model.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_java_model</span><span class="o">.</span><span class="n">toString</span><span class="p">()</span>
<div class="viewcode-block" id="DecisionTreeModel.toDebugString"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTreeModel.html#pyspark.mllib.tree.DecisionTreeModel.toDebugString">[docs]</a> <span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.2.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">toDebugString</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Full description of the model.&quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_java_model</span><span class="o">.</span><span class="n">toDebugString</span><span class="p">()</span></div>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_java_loader_class</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="s2">&quot;org.apache.spark.mllib.tree.model.DecisionTreeModel&quot;</span></div>
<div class="viewcode-block" id="DecisionTree"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTree.html#pyspark.mllib.tree.DecisionTree">[docs]</a><span class="k">class</span> <span class="nc">DecisionTree</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Learning algorithm for a decision tree model for classification or</span>
<span class="sd"> regression.</span>
<span class="sd"> .. versionadded:: 1.1.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_train</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="nb">type</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">numClasses</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">features</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DecisionTreeModel</span><span class="p">:</span>
<span class="n">first</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">first</span><span class="p">()</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">first</span><span class="p">,</span> <span class="n">LabeledPoint</span><span class="p">),</span> <span class="s2">&quot;the data should be RDD of LabeledPoint&quot;</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">callMLlibFunc</span><span class="p">(</span>
<span class="s2">&quot;trainDecisionTreeModel&quot;</span><span class="p">,</span>
<span class="n">data</span><span class="p">,</span>
<span class="nb">type</span><span class="p">,</span>
<span class="n">numClasses</span><span class="p">,</span>
<span class="n">features</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">DecisionTreeModel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<div class="viewcode-block" id="DecisionTree.trainClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTree.html#pyspark.mllib.tree.DecisionTree.trainClassifier">[docs]</a> <span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">trainClassifier</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">numClasses</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DecisionTreeModel</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train a decision tree model for classification.</span>
<span class="sd"> .. versionadded:: 1.1.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data : :py:class:`pyspark.RDD`</span>
<span class="sd"> Training data: RDD of LabeledPoint. Labels should take values</span>
<span class="sd"> {0, 1, ..., numClasses-1}.</span>
<span class="sd"> numClasses : int</span>
<span class="sd"> Number of classes for classification.</span>
<span class="sd"> categoricalFeaturesInfo : dict</span>
<span class="sd"> Map storing arity of categorical features. An entry (n -&gt; k)</span>
<span class="sd"> indicates that feature n is categorical with k categories</span>
<span class="sd"> indexed from 0: {0, 1, ..., k-1}.</span>
<span class="sd"> impurity : str, optional</span>
<span class="sd"> Criterion used for information gain calculation.</span>
<span class="sd"> Supported values: &quot;gini&quot; or &quot;entropy&quot;.</span>
<span class="sd"> (default: &quot;gini&quot;)</span>
<span class="sd"> maxDepth : int, optional</span>
<span class="sd"> Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1</span>
<span class="sd"> means 1 internal node + 2 leaf nodes).</span>
<span class="sd"> (default: 5)</span>
<span class="sd"> maxBins : int, optional</span>
<span class="sd"> Number of bins used for finding splits at each node.</span>
<span class="sd"> (default: 32)</span>
<span class="sd"> minInstancesPerNode : int, optional</span>
<span class="sd"> Minimum number of instances required at child nodes to create</span>
<span class="sd"> the parent split.</span>
<span class="sd"> (default: 1)</span>
<span class="sd"> minInfoGain : float, optional</span>
<span class="sd"> Minimum info gain required to create a split.</span>
<span class="sd"> (default: 0.0)</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`DecisionTreeModel`</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from numpy import array</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.regression import LabeledPoint</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.tree import DecisionTree</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; data = [</span>
<span class="sd"> ... LabeledPoint(0.0, [0.0]),</span>
<span class="sd"> ... LabeledPoint(1.0, [1.0]),</span>
<span class="sd"> ... LabeledPoint(1.0, [2.0]),</span>
<span class="sd"> ... LabeledPoint(1.0, [3.0])</span>
<span class="sd"> ... ]</span>
<span class="sd"> &gt;&gt;&gt; model = DecisionTree.trainClassifier(sc.parallelize(data), 2, {})</span>
<span class="sd"> &gt;&gt;&gt; print(model)</span>
<span class="sd"> DecisionTreeModel classifier of depth 1 with 3 nodes</span>
<span class="sd"> &gt;&gt;&gt; print(model.toDebugString())</span>
<span class="sd"> DecisionTreeModel classifier of depth 1 with 3 nodes</span>
<span class="sd"> If (feature 0 &lt;= 0.5)</span>
<span class="sd"> Predict: 0.0</span>
<span class="sd"> Else (feature 0 &gt; 0.5)</span>
<span class="sd"> Predict: 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict(array([1.0]))</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict(array([0.0]))</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; rdd = sc.parallelize([[1.0], [0.0]])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(rdd).collect()</span>
<span class="sd"> [1.0, 0.0]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_train</span><span class="p">(</span>
<span class="n">data</span><span class="p">,</span>
<span class="s2">&quot;classification&quot;</span><span class="p">,</span>
<span class="n">numClasses</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">,</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="DecisionTree.trainRegressor"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.DecisionTree.html#pyspark.mllib.tree.DecisionTree.trainRegressor">[docs]</a> <span class="nd">@classmethod</span>
<span class="nd">@since</span><span class="p">(</span><span class="s2">&quot;1.1.0&quot;</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">trainRegressor</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;variance&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">5</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.0</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">DecisionTreeModel</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train a decision tree model for regression.</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data : :py:class:`pyspark.RDD`</span>
<span class="sd"> Training data: RDD of LabeledPoint. Labels are real numbers.</span>
<span class="sd"> categoricalFeaturesInfo : dict</span>
<span class="sd"> Map storing arity of categorical features. An entry (n -&gt; k)</span>
<span class="sd"> indicates that feature n is categorical with k categories</span>
<span class="sd"> indexed from 0: {0, 1, ..., k-1}.</span>
<span class="sd"> impurity : str, optional</span>
<span class="sd"> Criterion used for information gain calculation.</span>
<span class="sd"> The only supported value for regression is &quot;variance&quot;.</span>
<span class="sd"> (default: &quot;variance&quot;)</span>
<span class="sd"> maxDepth : int, optional</span>
<span class="sd"> Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1</span>
<span class="sd"> means 1 internal node + 2 leaf nodes).</span>
<span class="sd"> (default: 5)</span>
<span class="sd"> maxBins : int, optional</span>
<span class="sd"> Number of bins used for finding splits at each node.</span>
<span class="sd"> (default: 32)</span>
<span class="sd"> minInstancesPerNode : int, optional</span>
<span class="sd"> Minimum number of instances required at child nodes to create</span>
<span class="sd"> the parent split.</span>
<span class="sd"> (default: 1)</span>
<span class="sd"> minInfoGain : float, optional</span>
<span class="sd"> Minimum info gain required to create a split.</span>
<span class="sd"> (default: 0.0)</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`DecisionTreeModel`</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.regression import LabeledPoint</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.tree import DecisionTree</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.linalg import SparseVector</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; sparse_data = [</span>
<span class="sd"> ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),</span>
<span class="sd"> ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),</span>
<span class="sd"> ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),</span>
<span class="sd"> ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))</span>
<span class="sd"> ... ]</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})</span>
<span class="sd"> &gt;&gt;&gt; model.predict(SparseVector(2, {1: 1.0}))</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict(SparseVector(2, {1: 0.0}))</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(rdd).collect()</span>
<span class="sd"> [1.0, 0.0]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_train</span><span class="p">(</span>
<span class="n">data</span><span class="p">,</span>
<span class="s2">&quot;regression&quot;</span><span class="p">,</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="n">minInstancesPerNode</span><span class="p">,</span>
<span class="n">minInfoGain</span><span class="p">,</span>
<span class="p">)</span></div></div>
<div class="viewcode-block" id="RandomForestModel"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.RandomForestModel.html#pyspark.mllib.tree.RandomForestModel">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">RandomForestModel</span><span class="p">(</span><span class="n">TreeEnsembleModel</span><span class="p">,</span> <span class="n">JavaLoader</span><span class="p">[</span><span class="s2">&quot;RandomForestModel&quot;</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Represents a random forest model.</span>
<span class="sd"> .. versionadded:: 1.2.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_java_loader_class</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="s2">&quot;org.apache.spark.mllib.tree.model.RandomForestModel&quot;</span></div>
<div class="viewcode-block" id="RandomForest"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.RandomForest.html#pyspark.mllib.tree.RandomForest">[docs]</a><span class="k">class</span> <span class="nc">RandomForest</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Learning algorithm for a random forest model for classification or</span>
<span class="sd"> regression.</span>
<span class="sd"> .. versionadded:: 1.2.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">supportedFeatureSubsetStrategies</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">str</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="s2">&quot;auto&quot;</span><span class="p">,</span> <span class="s2">&quot;all&quot;</span><span class="p">,</span> <span class="s2">&quot;sqrt&quot;</span><span class="p">,</span> <span class="s2">&quot;log2&quot;</span><span class="p">,</span> <span class="s2">&quot;onethird&quot;</span><span class="p">)</span>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_train</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">algo</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">numClasses</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">numTrees</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">RandomForestModel</span><span class="p">:</span>
<span class="n">first</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">first</span><span class="p">()</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">first</span><span class="p">,</span> <span class="n">LabeledPoint</span><span class="p">),</span> <span class="s2">&quot;the data should be RDD of LabeledPoint&quot;</span>
<span class="k">if</span> <span class="n">featureSubsetStrategy</span> <span class="ow">not</span> <span class="ow">in</span> <span class="bp">cls</span><span class="o">.</span><span class="n">supportedFeatureSubsetStrategies</span><span class="p">:</span>
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;unsupported featureSubsetStrategy: </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="n">featureSubsetStrategy</span><span class="p">)</span>
<span class="k">if</span> <span class="n">seed</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">seed</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="mi">30</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">callMLlibFunc</span><span class="p">(</span>
<span class="s2">&quot;trainRandomForestModel&quot;</span><span class="p">,</span>
<span class="n">data</span><span class="p">,</span>
<span class="n">algo</span><span class="p">,</span>
<span class="n">numClasses</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">numTrees</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="n">seed</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">RandomForestModel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<div class="viewcode-block" id="RandomForest.trainClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.RandomForest.html#pyspark.mllib.tree.RandomForest.trainClassifier">[docs]</a> <span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">trainClassifier</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">numClasses</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">numTrees</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;gini&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">RandomForestModel</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train a random forest model for binary or multiclass</span>
<span class="sd"> classification.</span>
<span class="sd"> .. versionadded:: 1.2.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data : :py:class:`pyspark.RDD`</span>
<span class="sd"> Training dataset: RDD of LabeledPoint. Labels should take values</span>
<span class="sd"> {0, 1, ..., numClasses-1}.</span>
<span class="sd"> numClasses : int</span>
<span class="sd"> Number of classes for classification.</span>
<span class="sd"> categoricalFeaturesInfo : dict</span>
<span class="sd"> Map storing arity of categorical features. An entry (n -&gt; k)</span>
<span class="sd"> indicates that feature n is categorical with k categories</span>
<span class="sd"> indexed from 0: {0, 1, ..., k-1}.</span>
<span class="sd"> numTrees : int</span>
<span class="sd"> Number of trees in the random forest.</span>
<span class="sd"> featureSubsetStrategy : str, optional</span>
<span class="sd"> Number of features to consider for splits at each node.</span>
<span class="sd"> Supported values: &quot;auto&quot;, &quot;all&quot;, &quot;sqrt&quot;, &quot;log2&quot;, &quot;onethird&quot;.</span>
<span class="sd"> If &quot;auto&quot; is set, this parameter is set based on numTrees:</span>
<span class="sd"> if numTrees == 1, set to &quot;all&quot;;</span>
<span class="sd"> if numTrees &gt; 1 (forest) set to &quot;sqrt&quot;.</span>
<span class="sd"> (default: &quot;auto&quot;)</span>
<span class="sd"> impurity : str, optional</span>
<span class="sd"> Criterion used for information gain calculation.</span>
<span class="sd"> Supported values: &quot;gini&quot; or &quot;entropy&quot;.</span>
<span class="sd"> (default: &quot;gini&quot;)</span>
<span class="sd"> maxDepth : int, optional</span>
<span class="sd"> Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1</span>
<span class="sd"> means 1 internal node + 2 leaf nodes).</span>
<span class="sd"> (default: 4)</span>
<span class="sd"> maxBins : int, optional</span>
<span class="sd"> Maximum number of bins used for splitting features.</span>
<span class="sd"> (default: 32)</span>
<span class="sd"> seed : int, optional</span>
<span class="sd"> Random seed for bootstrapping and choosing feature subsets.</span>
<span class="sd"> Set as None to generate a random seed.</span>
<span class="sd"> (default: None)</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`RandomForestModel`</span>
<span class="sd"> that can be used for prediction.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.regression import LabeledPoint</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.tree import RandomForest</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; data = [</span>
<span class="sd"> ... LabeledPoint(0.0, [0.0]),</span>
<span class="sd"> ... LabeledPoint(0.0, [1.0]),</span>
<span class="sd"> ... LabeledPoint(1.0, [2.0]),</span>
<span class="sd"> ... LabeledPoint(1.0, [3.0])</span>
<span class="sd"> ... ]</span>
<span class="sd"> &gt;&gt;&gt; model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42)</span>
<span class="sd"> &gt;&gt;&gt; model.numTrees()</span>
<span class="sd"> 3</span>
<span class="sd"> &gt;&gt;&gt; model.totalNumNodes()</span>
<span class="sd"> 7</span>
<span class="sd"> &gt;&gt;&gt; print(model)</span>
<span class="sd"> TreeEnsembleModel classifier with 3 trees</span>
<span class="sd"> &gt;&gt;&gt; print(model.toDebugString())</span>
<span class="sd"> TreeEnsembleModel classifier with 3 trees</span>
<span class="sd"> Tree 0:</span>
<span class="sd"> Predict: 1.0</span>
<span class="sd"> Tree 1:</span>
<span class="sd"> If (feature 0 &lt;= 1.5)</span>
<span class="sd"> Predict: 0.0</span>
<span class="sd"> Else (feature 0 &gt; 1.5)</span>
<span class="sd"> Predict: 1.0</span>
<span class="sd"> Tree 2:</span>
<span class="sd"> If (feature 0 &lt;= 1.5)</span>
<span class="sd"> Predict: 0.0</span>
<span class="sd"> Else (feature 0 &gt; 1.5)</span>
<span class="sd"> Predict: 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict([2.0])</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict([0.0])</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; rdd = sc.parallelize([[3.0], [1.0]])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(rdd).collect()</span>
<span class="sd"> [1.0, 0.0]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_train</span><span class="p">(</span>
<span class="n">data</span><span class="p">,</span>
<span class="s2">&quot;classification&quot;</span><span class="p">,</span>
<span class="n">numClasses</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">numTrees</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="n">seed</span><span class="p">,</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="RandomForest.trainRegressor"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.RandomForest.html#pyspark.mllib.tree.RandomForest.trainRegressor">[docs]</a> <span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">trainRegressor</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">numTrees</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;auto&quot;</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;variance&quot;</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">4</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="n">seed</span><span class="p">:</span> <span class="n">Optional</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="o">=</span> <span class="kc">None</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">RandomForestModel</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train a random forest model for regression.</span>
<span class="sd"> .. versionadded:: 1.2.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data : :py:class:`pyspark.RDD`</span>
<span class="sd"> Training dataset: RDD of LabeledPoint. Labels are real numbers.</span>
<span class="sd"> categoricalFeaturesInfo : dict</span>
<span class="sd"> Map storing arity of categorical features. An entry (n -&gt; k)</span>
<span class="sd"> indicates that feature n is categorical with k categories</span>
<span class="sd"> indexed from 0: {0, 1, ..., k-1}.</span>
<span class="sd"> numTrees : int</span>
<span class="sd"> Number of trees in the random forest.</span>
<span class="sd"> featureSubsetStrategy : str, optional</span>
<span class="sd"> Number of features to consider for splits at each node.</span>
<span class="sd"> Supported values: &quot;auto&quot;, &quot;all&quot;, &quot;sqrt&quot;, &quot;log2&quot;, &quot;onethird&quot;.</span>
<span class="sd"> If &quot;auto&quot; is set, this parameter is set based on numTrees:</span>
<span class="sd"> - if numTrees == 1, set to &quot;all&quot;;</span>
<span class="sd"> - if numTrees &gt; 1 (forest) set to &quot;onethird&quot; for regression.</span>
<span class="sd"> (default: &quot;auto&quot;)</span>
<span class="sd"> impurity : str, optional</span>
<span class="sd"> Criterion used for information gain calculation.</span>
<span class="sd"> The only supported value for regression is &quot;variance&quot;.</span>
<span class="sd"> (default: &quot;variance&quot;)</span>
<span class="sd"> maxDepth : int, optional</span>
<span class="sd"> Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1</span>
<span class="sd"> means 1 internal node + 2 leaf nodes).</span>
<span class="sd"> (default: 4)</span>
<span class="sd"> maxBins : int, optional</span>
<span class="sd"> Maximum number of bins used for splitting features.</span>
<span class="sd"> (default: 32)</span>
<span class="sd"> seed : int, optional</span>
<span class="sd"> Random seed for bootstrapping and choosing feature subsets.</span>
<span class="sd"> Set as None to generate seed based on system time.</span>
<span class="sd"> (default: None)</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`RandomForestModel`</span>
<span class="sd"> that can be used for prediction.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.regression import LabeledPoint</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.tree import RandomForest</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.linalg import SparseVector</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; sparse_data = [</span>
<span class="sd"> ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),</span>
<span class="sd"> ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),</span>
<span class="sd"> ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),</span>
<span class="sd"> ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))</span>
<span class="sd"> ... ]</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42)</span>
<span class="sd"> &gt;&gt;&gt; model.numTrees()</span>
<span class="sd"> 2</span>
<span class="sd"> &gt;&gt;&gt; model.totalNumNodes()</span>
<span class="sd"> 4</span>
<span class="sd"> &gt;&gt;&gt; model.predict(SparseVector(2, {1: 1.0}))</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict(SparseVector(2, {0: 1.0}))</span>
<span class="sd"> 0.5</span>
<span class="sd"> &gt;&gt;&gt; rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(rdd).collect()</span>
<span class="sd"> [1.0, 0.5]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_train</span><span class="p">(</span>
<span class="n">data</span><span class="p">,</span>
<span class="s2">&quot;regression&quot;</span><span class="p">,</span>
<span class="mi">0</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">numTrees</span><span class="p">,</span>
<span class="n">featureSubsetStrategy</span><span class="p">,</span>
<span class="n">impurity</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="n">seed</span><span class="p">,</span>
<span class="p">)</span></div></div>
<div class="viewcode-block" id="GradientBoostedTreesModel"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.GradientBoostedTreesModel.html#pyspark.mllib.tree.GradientBoostedTreesModel">[docs]</a><span class="nd">@inherit_doc</span>
<span class="k">class</span> <span class="nc">GradientBoostedTreesModel</span><span class="p">(</span><span class="n">TreeEnsembleModel</span><span class="p">,</span> <span class="n">JavaLoader</span><span class="p">[</span><span class="s2">&quot;GradientBoostedTreesModel&quot;</span><span class="p">]):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Represents a gradient-boosted tree model.</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_java_loader_class</span><span class="p">(</span><span class="bp">cls</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="nb">str</span><span class="p">:</span>
<span class="k">return</span> <span class="s2">&quot;org.apache.spark.mllib.tree.model.GradientBoostedTreesModel&quot;</span></div>
<div class="viewcode-block" id="GradientBoostedTrees"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.GradientBoostedTrees.html#pyspark.mllib.tree.GradientBoostedTrees">[docs]</a><span class="k">class</span> <span class="nc">GradientBoostedTrees</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Learning algorithm for a gradient-boosted trees model for</span>
<span class="sd"> classification or regression.</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">_train</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">algo</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">loss</span><span class="p">:</span> <span class="nb">str</span><span class="p">,</span>
<span class="n">numIterations</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">learningRate</span><span class="p">:</span> <span class="nb">float</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">GradientBoostedTreesModel</span><span class="p">:</span>
<span class="n">first</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">first</span><span class="p">()</span>
<span class="k">assert</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">first</span><span class="p">,</span> <span class="n">LabeledPoint</span><span class="p">),</span> <span class="s2">&quot;the data should be RDD of LabeledPoint&quot;</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">callMLlibFunc</span><span class="p">(</span>
<span class="s2">&quot;trainGradientBoostedTreesModel&quot;</span><span class="p">,</span>
<span class="n">data</span><span class="p">,</span>
<span class="n">algo</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">loss</span><span class="p">,</span>
<span class="n">numIterations</span><span class="p">,</span>
<span class="n">learningRate</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="p">)</span>
<span class="k">return</span> <span class="n">GradientBoostedTreesModel</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<div class="viewcode-block" id="GradientBoostedTrees.trainClassifier"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.GradientBoostedTrees.html#pyspark.mllib.tree.GradientBoostedTrees.trainClassifier">[docs]</a> <span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">trainClassifier</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">loss</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;logLoss&quot;</span><span class="p">,</span>
<span class="n">numIterations</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">learningRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">GradientBoostedTreesModel</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train a gradient-boosted trees model for classification.</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data : :py:class:`pyspark.RDD`</span>
<span class="sd"> Training dataset: RDD of LabeledPoint. Labels should take values</span>
<span class="sd"> {0, 1}.</span>
<span class="sd"> categoricalFeaturesInfo : dict</span>
<span class="sd"> Map storing arity of categorical features. An entry (n -&gt; k)</span>
<span class="sd"> indicates that feature n is categorical with k categories</span>
<span class="sd"> indexed from 0: {0, 1, ..., k-1}.</span>
<span class="sd"> loss : str, optional</span>
<span class="sd"> Loss function used for minimization during gradient boosting.</span>
<span class="sd"> Supported values: &quot;logLoss&quot;, &quot;leastSquaresError&quot;,</span>
<span class="sd"> &quot;leastAbsoluteError&quot;.</span>
<span class="sd"> (default: &quot;logLoss&quot;)</span>
<span class="sd"> numIterations : int, optional</span>
<span class="sd"> Number of iterations of boosting.</span>
<span class="sd"> (default: 100)</span>
<span class="sd"> learningRate : float, optional</span>
<span class="sd"> Learning rate for shrinking the contribution of each estimator.</span>
<span class="sd"> The learning rate should be in the interval (0, 1].</span>
<span class="sd"> (default: 0.1)</span>
<span class="sd"> maxDepth : int, optional</span>
<span class="sd"> Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1</span>
<span class="sd"> means 1 internal node + 2 leaf nodes).</span>
<span class="sd"> (default: 3)</span>
<span class="sd"> maxBins : int, optional</span>
<span class="sd"> Maximum number of bins used for splitting features. DecisionTree</span>
<span class="sd"> requires maxBins &gt;= max categories.</span>
<span class="sd"> (default: 32)</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`GradientBoostedTreesModel`</span>
<span class="sd"> that can be used for prediction.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.regression import LabeledPoint</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.tree import GradientBoostedTrees</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; data = [</span>
<span class="sd"> ... LabeledPoint(0.0, [0.0]),</span>
<span class="sd"> ... LabeledPoint(0.0, [1.0]),</span>
<span class="sd"> ... LabeledPoint(1.0, [2.0]),</span>
<span class="sd"> ... LabeledPoint(1.0, [3.0])</span>
<span class="sd"> ... ]</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10)</span>
<span class="sd"> &gt;&gt;&gt; model.numTrees()</span>
<span class="sd"> 10</span>
<span class="sd"> &gt;&gt;&gt; model.totalNumNodes()</span>
<span class="sd"> 30</span>
<span class="sd"> &gt;&gt;&gt; print(model) # it already has newline</span>
<span class="sd"> TreeEnsembleModel classifier with 10 trees</span>
<span class="sd"> &gt;&gt;&gt; model.predict([2.0])</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict([0.0])</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; rdd = sc.parallelize([[2.0], [0.0]])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(rdd).collect()</span>
<span class="sd"> [1.0, 0.0]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_train</span><span class="p">(</span>
<span class="n">data</span><span class="p">,</span>
<span class="s2">&quot;classification&quot;</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">loss</span><span class="p">,</span>
<span class="n">numIterations</span><span class="p">,</span>
<span class="n">learningRate</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="p">)</span></div>
<div class="viewcode-block" id="GradientBoostedTrees.trainRegressor"><a class="viewcode-back" href="../../../reference/api/pyspark.mllib.tree.GradientBoostedTrees.html#pyspark.mllib.tree.GradientBoostedTrees.trainRegressor">[docs]</a> <span class="nd">@classmethod</span>
<span class="k">def</span> <span class="nf">trainRegressor</span><span class="p">(</span>
<span class="bp">cls</span><span class="p">,</span>
<span class="n">data</span><span class="p">:</span> <span class="n">RDD</span><span class="p">[</span><span class="n">LabeledPoint</span><span class="p">],</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">:</span> <span class="n">Dict</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="nb">int</span><span class="p">],</span>
<span class="n">loss</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;leastSquaresError&quot;</span><span class="p">,</span>
<span class="n">numIterations</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">100</span><span class="p">,</span>
<span class="n">learningRate</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">3</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">32</span><span class="p">,</span>
<span class="p">)</span> <span class="o">-&gt;</span> <span class="n">GradientBoostedTreesModel</span><span class="p">:</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Train a gradient-boosted trees model for regression.</span>
<span class="sd"> .. versionadded:: 1.3.0</span>
<span class="sd"> Parameters</span>
<span class="sd"> ----------</span>
<span class="sd"> data : :py:class:`pyspark.RDD`</span>
<span class="sd"> Training dataset: RDD of LabeledPoint. Labels are real numbers.</span>
<span class="sd"> categoricalFeaturesInfo : dict</span>
<span class="sd"> Map storing arity of categorical features. An entry (n -&gt; k)</span>
<span class="sd"> indicates that feature n is categorical with k categories</span>
<span class="sd"> indexed from 0: {0, 1, ..., k-1}.</span>
<span class="sd"> loss : str, optional</span>
<span class="sd"> Loss function used for minimization during gradient boosting.</span>
<span class="sd"> Supported values: &quot;logLoss&quot;, &quot;leastSquaresError&quot;,</span>
<span class="sd"> &quot;leastAbsoluteError&quot;.</span>
<span class="sd"> (default: &quot;leastSquaresError&quot;)</span>
<span class="sd"> numIterations : int, optional</span>
<span class="sd"> Number of iterations of boosting.</span>
<span class="sd"> (default: 100)</span>
<span class="sd"> learningRate : float, optional</span>
<span class="sd"> Learning rate for shrinking the contribution of each estimator.</span>
<span class="sd"> The learning rate should be in the interval (0, 1].</span>
<span class="sd"> (default: 0.1)</span>
<span class="sd"> maxDepth : int, optional</span>
<span class="sd"> Maximum depth of tree (e.g. depth 0 means 1 leaf node, depth 1</span>
<span class="sd"> means 1 internal node + 2 leaf nodes).</span>
<span class="sd"> (default: 3)</span>
<span class="sd"> maxBins : int, optional</span>
<span class="sd"> Maximum number of bins used for splitting features. DecisionTree</span>
<span class="sd"> requires maxBins &gt;= max categories.</span>
<span class="sd"> (default: 32)</span>
<span class="sd"> Returns</span>
<span class="sd"> -------</span>
<span class="sd"> :py:class:`GradientBoostedTreesModel`</span>
<span class="sd"> that can be used for prediction.</span>
<span class="sd"> Examples</span>
<span class="sd"> --------</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.regression import LabeledPoint</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.tree import GradientBoostedTrees</span>
<span class="sd"> &gt;&gt;&gt; from pyspark.mllib.linalg import SparseVector</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; sparse_data = [</span>
<span class="sd"> ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),</span>
<span class="sd"> ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),</span>
<span class="sd"> ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),</span>
<span class="sd"> ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))</span>
<span class="sd"> ... ]</span>
<span class="sd"> &gt;&gt;&gt;</span>
<span class="sd"> &gt;&gt;&gt; data = sc.parallelize(sparse_data)</span>
<span class="sd"> &gt;&gt;&gt; model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10)</span>
<span class="sd"> &gt;&gt;&gt; model.numTrees()</span>
<span class="sd"> 10</span>
<span class="sd"> &gt;&gt;&gt; model.totalNumNodes()</span>
<span class="sd"> 12</span>
<span class="sd"> &gt;&gt;&gt; model.predict(SparseVector(2, {1: 1.0}))</span>
<span class="sd"> 1.0</span>
<span class="sd"> &gt;&gt;&gt; model.predict(SparseVector(2, {0: 1.0}))</span>
<span class="sd"> 0.0</span>
<span class="sd"> &gt;&gt;&gt; rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])</span>
<span class="sd"> &gt;&gt;&gt; model.predict(rdd).collect()</span>
<span class="sd"> [1.0, 0.0]</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="k">return</span> <span class="bp">cls</span><span class="o">.</span><span class="n">_train</span><span class="p">(</span>
<span class="n">data</span><span class="p">,</span>
<span class="s2">&quot;regression&quot;</span><span class="p">,</span>
<span class="n">categoricalFeaturesInfo</span><span class="p">,</span>
<span class="n">loss</span><span class="p">,</span>
<span class="n">numIterations</span><span class="p">,</span>
<span class="n">learningRate</span><span class="p">,</span>
<span class="n">maxDepth</span><span class="p">,</span>
<span class="n">maxBins</span><span class="p">,</span>
<span class="p">)</span></div></div>
<span class="k">def</span> <span class="nf">_test</span><span class="p">()</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span>
<span class="kn">import</span> <span class="nn">doctest</span>
<span class="n">globs</span> <span class="o">=</span> <span class="nb">globals</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span>
<span class="kn">from</span> <span class="nn">pyspark.sql</span> <span class="kn">import</span> <span class="n">SparkSession</span>
<span class="n">spark</span> <span class="o">=</span> <span class="n">SparkSession</span><span class="o">.</span><span class="n">builder</span><span class="o">.</span><span class="n">master</span><span class="p">(</span><span class="s2">&quot;local[4]&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">appName</span><span class="p">(</span><span class="s2">&quot;mllib.tree tests&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">getOrCreate</span><span class="p">()</span>
<span class="n">globs</span><span class="p">[</span><span class="s2">&quot;sc&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">sparkContext</span>
<span class="p">(</span><span class="n">failure_count</span><span class="p">,</span> <span class="n">test_count</span><span class="p">)</span> <span class="o">=</span> <span class="n">doctest</span><span class="o">.</span><span class="n">testmod</span><span class="p">(</span>
<span class="n">globs</span><span class="o">=</span><span class="n">globs</span><span class="p">,</span> <span class="n">optionflags</span><span class="o">=</span><span class="n">doctest</span><span class="o">.</span><span class="n">ELLIPSIS</span> <span class="o">|</span> <span class="n">doctest</span><span class="o">.</span><span class="n">NORMALIZE_WHITESPACE</span>
<span class="p">)</span>
<span class="n">spark</span><span class="o">.</span><span class="n">stop</span><span class="p">()</span>
<span class="k">if</span> <span class="n">failure_count</span><span class="p">:</span>
<span class="n">sys</span><span class="o">.</span><span class="n">exit</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">&quot;__main__&quot;</span><span class="p">:</span>
<span class="n">_test</span><span class="p">()</span>
</pre></div>
</article>
<footer class="bd-footer-article">
<div class="footer-article-items footer-article__inner">
<div class="footer-article-item"><!-- Previous / next buttons -->
<div class="prev-next-area">
</div></div>
</div>
</footer>
</div>
</div>
<footer class="bd-footer-content">
</footer>
</main>
</div>
</div>
<!-- Scripts loaded after <body> so the DOM is not blocked -->
<script src="../../../_static/scripts/bootstrap.js?digest=e353d410970836974a52"></script>
<script src="../../../_static/scripts/pydata-sphinx-theme.js?digest=e353d410970836974a52"></script>
<footer class="bd-footer">
<div class="bd-footer__inner bd-page-width">
<div class="footer-items__start">
<div class="footer-item"><p class="copyright">
Copyright © 2024 The Apache Software Foundation, Licensed under the <a href="https://www.apache.org/licenses/LICENSE-2.0">Apache License, Version 2.0</a>.
</p></div>
<div class="footer-item">
<p class="sphinx-version">
Created using <a href="https://www.sphinx-doc.org/">Sphinx</a> 4.5.0.
<br/>
</p>
</div>
</div>
<div class="footer-items__end">
<div class="footer-item"><p class="theme-version">
Built with the <a href="https://pydata-sphinx-theme.readthedocs.io/en/stable/index.html">PyData Sphinx Theme</a> 0.13.3.
</p></div>
</div>
</div>
</footer>
</body>
</html>