<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9"> <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js"> <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>Classification and regression - spark.ml - Spark 2.0.0 Documentation</title>
<link rel="stylesheet" href="css/bootstrap.min.css">
<style>
body {
padding-top: 60px;
padding-bottom: 40px;
}
</style>
<meta name="viewport" content="width=device-width">
<link rel="stylesheet" href="css/bootstrap-responsive.min.css">
<link rel="stylesheet" href="css/main.css">
<script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script>
<link rel="stylesheet" href="css/pygments-default.css">
</head>
<body>
<!--[if lt IE 7]>
<p class="chromeframe">You are using an outdated browser. <a href="http://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p>
<![endif]-->
<!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html -->
<div class="navbar navbar-fixed-top" id="topbar">
<div class="navbar-inner">
<div class="container">
<div class="brand"><a href="index.html">
<img src="img/spark-logo-hd.png" style="height:50px;"/></a><span class="version">2.0.0</span>
</div>
<ul class="nav">
<!--TODO(andyk): Add class="active" attribute to li somehow.-->
<li><a href="index.html">Overview</a></li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="quick-start.html">Quick Start</a></li>
<li><a href="programming-guide.html">Spark Programming Guide</a></li>
<li class="divider"></li>
<li><a href="streaming-programming-guide.html">Spark Streaming</a></li>
<li><a href="sql-programming-guide.html">DataFrames, Datasets and SQL</a></li>
<li><a href="mllib-guide.html">MLlib (Machine Learning)</a></li>
<li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li>
<li><a href="sparkr.html">SparkR (R on Spark)</a></li>
</ul>
</li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li>
<li><a href="api/java/index.html">Java</a></li>
<li><a href="api/python/index.html">Python</a></li>
<li><a href="api/R/index.html">R</a></li>
</ul>
</li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="cluster-overview.html">Overview</a></li>
<li><a href="submitting-applications.html">Submitting Applications</a></li>
<li class="divider"></li>
<li><a href="spark-standalone.html">Spark Standalone</a></li>
<li><a href="running-on-mesos.html">Mesos</a></li>
<li><a href="running-on-yarn.html">YARN</a></li>
</ul>
</li>
<li class="dropdown">
<a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="configuration.html">Configuration</a></li>
<li><a href="monitoring.html">Monitoring</a></li>
<li><a href="tuning.html">Tuning Guide</a></li>
<li><a href="job-scheduling.html">Job Scheduling</a></li>
<li><a href="security.html">Security</a></li>
<li><a href="hardware-provisioning.html">Hardware Provisioning</a></li>
<li class="divider"></li>
<li><a href="building-spark.html">Building Spark</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark">Contributing to Spark</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects">Supplemental Projects</a></li>
</ul>
</li>
</ul>
<!--<p class="navbar-text pull-right"><span class="version-text">v2.0.0</span></p>-->
</div>
</div>
</div>
<div class="container-wrapper">
<div class="left-menu-wrapper">
<div class="left-menu">
<h3><a href="ml-guide.html">spark.ml package</a></h3>
<ul>
<li>
<a href="ml-guide.html">
Overview: estimators, transformers and pipelines
</a>
</li>
<li>
<a href="ml-features.html">
Extracting, transforming and selecting features
</a>
</li>
<li>
<a href="ml-classification-regression.html">
<b>Classification and Regression</b>
</a>
</li>
<li>
<a href="ml-clustering.html">
Clustering
</a>
</li>
<li>
<a href="ml-collaborative-filtering.html">
Collaborative filtering
</a>
</li>
<li>
<a href="ml-advanced.html">
Advanced topics
</a>
</li>
</ul>
<h3><a href="mllib-guide.html">spark.mllib package</a></h3>
<ul>
<li>
<a href="mllib-data-types.html">
Data types
</a>
</li>
<li>
<a href="mllib-statistics.html">
Basic statistics
</a>
</li>
<li>
<a href="mllib-classification-regression.html">
Classification and regression
</a>
</li>
<li>
<a href="mllib-collaborative-filtering.html">
Collaborative filtering
</a>
</li>
<li>
<a href="mllib-clustering.html">
Clustering
</a>
</li>
<li>
<a href="mllib-dimensionality-reduction.html">
Dimensionality reduction
</a>
</li>
<li>
<a href="mllib-feature-extraction.html">
Feature extraction and transformation
</a>
</li>
<li>
<a href="mllib-frequent-pattern-mining.html">
Frequent pattern mining
</a>
</li>
<li>
<a href="mllib-evaluation-metrics.html">
Evaluation metrics
</a>
</li>
<li>
<a href="mllib-pmml-model-export.html">
PMML model export
</a>
</li>
<li>
<a href="mllib-optimization.html">
Optimization (developer)
</a>
</li>
</ul>
</div>
</div>
<input id="nav-trigger" class="nav-trigger" checked type="checkbox">
<label for="nav-trigger"></label>
<div class="content-with-sidebar" id="content">
<h1 class="title">Classification and regression - spark.ml</h1>
<p><code>\[
\newcommand{\R}{\mathbb{R}}
\newcommand{\E}{\mathbb{E}}
\newcommand{\x}{\mathbf{x}}
\newcommand{\y}{\mathbf{y}}
\newcommand{\wv}{\mathbf{w}}
\newcommand{\av}{\mathbf{\alpha}}
\newcommand{\bv}{\mathbf{b}}
\newcommand{\N}{\mathbb{N}}
\newcommand{\id}{\mathbf{I}}
\newcommand{\ind}{\mathbf{1}}
\newcommand{\0}{\mathbf{0}}
\newcommand{\unit}{\mathbf{e}}
\newcommand{\one}{\mathbf{1}}
\newcommand{\zero}{\mathbf{0}}
\]</code></p>
<p><strong>Table of Contents</strong></p>
<ul id="markdown-toc">
<li><a href="#classification" id="markdown-toc-classification">Classification</a> <ul>
<li><a href="#logistic-regression" id="markdown-toc-logistic-regression">Logistic regression</a></li>
<li><a href="#decision-tree-classifier" id="markdown-toc-decision-tree-classifier">Decision tree classifier</a></li>
<li><a href="#random-forest-classifier" id="markdown-toc-random-forest-classifier">Random forest classifier</a></li>
<li><a href="#gradient-boosted-tree-classifier" id="markdown-toc-gradient-boosted-tree-classifier">Gradient-boosted tree classifier</a></li>
<li><a href="#multilayer-perceptron-classifier" id="markdown-toc-multilayer-perceptron-classifier">Multilayer perceptron classifier</a></li>
<li><a href="#one-vs-rest-classifier-aka-one-vs-all" id="markdown-toc-one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</a></li>
<li><a href="#naive-bayes" id="markdown-toc-naive-bayes">Naive Bayes</a></li>
</ul>
</li>
<li><a href="#regression" id="markdown-toc-regression">Regression</a> <ul>
<li><a href="#linear-regression" id="markdown-toc-linear-regression">Linear regression</a></li>
<li><a href="#decision-tree-regression" id="markdown-toc-decision-tree-regression">Decision tree regression</a></li>
<li><a href="#random-forest-regression" id="markdown-toc-random-forest-regression">Random forest regression</a></li>
<li><a href="#gradient-boosted-tree-regression" id="markdown-toc-gradient-boosted-tree-regression">Gradient-boosted tree regression</a></li>
<li><a href="#survival-regression" id="markdown-toc-survival-regression">Survival regression</a></li>
</ul>
</li>
<li><a href="#decision-trees" id="markdown-toc-decision-trees">Decision trees</a> <ul>
<li><a href="#inputs-and-outputs" id="markdown-toc-inputs-and-outputs">Inputs and Outputs</a> <ul>
<li><a href="#input-columns" id="markdown-toc-input-columns">Input Columns</a></li>
<li><a href="#output-columns" id="markdown-toc-output-columns">Output Columns</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#tree-ensembles" id="markdown-toc-tree-ensembles">Tree Ensembles</a> <ul>
<li><a href="#random-forests" id="markdown-toc-random-forests">Random Forests</a> <ul>
<li><a href="#inputs-and-outputs-1" id="markdown-toc-inputs-and-outputs-1">Inputs and Outputs</a> <ul>
<li><a href="#input-columns-1" id="markdown-toc-input-columns-1">Input Columns</a></li>
<li><a href="#output-columns-predictions" id="markdown-toc-output-columns-predictions">Output Columns (Predictions)</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#gradient-boosted-trees-gbts" id="markdown-toc-gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</a> <ul>
<li><a href="#inputs-and-outputs-2" id="markdown-toc-inputs-and-outputs-2">Inputs and Outputs</a> <ul>
<li><a href="#input-columns-2" id="markdown-toc-input-columns-2">Input Columns</a></li>
<li><a href="#output-columns-predictions-1" id="markdown-toc-output-columns-predictions-1">Output Columns (Predictions)</a></li>
</ul>
</li>
</ul>
</li>
</ul>
</li>
</ul>
<p>In <code>spark.ml</code>, we implement popular linear methods such as logistic
regression and linear least squares with $L_1$ or $L_2$ regularization.
Refer to <a href="mllib-linear-methods.html">the linear methods in mllib</a> for
details about implementation and tuning. We also include a DataFrame API for <a href="http://en.wikipedia.org/wiki/Elastic_net_regularization">Elastic
net</a>, a hybrid
of $L_1$ and $L_2$ regularization proposed in <a href="http://users.stat.umn.edu/~zouxx019/Papers/elasticnet.pdf">Zou et al., Regularization
and variable selection via the elastic
net</a>.
Mathematically, it is defined as a convex combination of the $L_1$ and
the $L_2$ regularization terms:
<code>\[
\alpha \left( \lambda \|\wv\|_1 \right) + (1-\alpha) \left( \frac{\lambda}{2}\|\wv\|_2^2 \right) , \alpha \in [0, 1], \lambda \geq 0
\]</code>
By setting $\alpha$ properly, elastic net contains both $L_1$ and $L_2$
regularization as special cases. For example, if a <a href="https://en.wikipedia.org/wiki/Linear_regression">linear
regression</a> model is
trained with the elastic net parameter $\alpha$ set to $1$, it is
equivalent to a
<a href="http://en.wikipedia.org/wiki/Least_squares#Lasso_method">Lasso</a> model.
On the other hand, if $\alpha$ is set to $0$, the trained model reduces
to a <a href="http://en.wikipedia.org/wiki/Tikhonov_regularization">ridge
regression</a> model.
The Pipelines API supports elastic net regularization for both linear regression and logistic regression.</p>
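<p>To make the formula concrete, the following plain-Python sketch evaluates the elastic net penalty for a small, made-up weight vector and checks the two special cases: $\alpha = 1$ leaves only the $L_1$ term (Lasso) and $\alpha = 0$ leaves only the $L_2$ term (ridge). The function <code>elastic_net_penalty</code> is a hypothetical illustration, not part of the Spark API; Spark's optimizers apply the penalty internally.</p>

```python
# Elastic net penalty, following the formula above:
#   alpha * (lam * ||w||_1) + (1 - alpha) * (lam / 2 * ||w||_2^2)
# Hypothetical helper for illustration only; not Spark's implementation.

def elastic_net_penalty(w, alpha, lam):
    """Return the elastic net regularization term for weight vector w."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if lam < 0.0:
        raise ValueError("lambda must be non-negative")
    l1 = sum(abs(wi) for wi in w)      # ||w||_1
    l2_sq = sum(wi * wi for wi in w)   # ||w||_2^2
    return alpha * (lam * l1) + (1.0 - alpha) * (lam / 2.0 * l2_sq)


w = [0.5, -2.0, 0.0, 1.5]   # made-up weights
lam = 0.3

lasso = elastic_net_penalty(w, alpha=1.0, lam=lam)  # only the L1 term remains
ridge = elastic_net_penalty(w, alpha=0.0, lam=lam)  # only the L2 term remains
mixed = elastic_net_penalty(w, alpha=0.8, lam=lam)  # convex combination of both

print(lasso, ridge, mixed)
```

<p>Because the penalty is linear in $\alpha$, the mixed value is exactly $0.8$ times the Lasso value plus $0.2$ times the ridge value.</p>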
<h1 id="classification">Classification</h1>
<h2 id="logistic-regression">Logistic regression</h2>
<p>Logistic regression is a popular method to predict a binary response. It is a special case of <a href="https://en.wikipedia.org/wiki/Generalized_linear_model">generalized linear models</a> that predicts the probability of the outcome.
For more background and details about the implementation, refer to the documentation of <a href="mllib-linear-methods.html#logistic-regression">logistic regression in <code>spark.mllib</code></a>.</p>
<blockquote>
<p>The current implementation of logistic regression in <code>spark.ml</code> only supports binary classification. Support for multiclass classification will be added in the future.</p>
</blockquote>
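<p>The prediction rule of a binary logistic regression model can be sketched in plain Python. Assuming a fitted model with coefficients $\wv$ and intercept $b$, the predicted probability of the positive class is the logistic function of the margin $\wv^T \x + b$, and the predicted label comes from comparing that probability with a threshold (0.5 in this sketch). The coefficient values and helper names below are hypothetical illustrations, not the <code>spark.ml</code> API.</p>

```python
import math

# Binary logistic regression prediction rule (illustrative sketch):
#   P(y = 1 | x) = 1 / (1 + exp(-(w . x + b)))
# The predicted label is 1 when that probability exceeds the threshold.

def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)  # avoids overflow in exp(-z) for large negative z
    return ez / (1.0 + ez)


def predict_proba(w, b, x):
    """Probability of the positive class for feature vector x."""
    margin = sum(wi * xi for wi, xi in zip(w, x)) + b
    return sigmoid(margin)


def predict(w, b, x, threshold=0.5):
    """Hard 0/1 label obtained by thresholding the probability."""
    return 1.0 if predict_proba(w, b, x) > threshold else 0.0


# Hypothetical coefficients and intercept, e.g. read from a fitted model.
w = [1.5, -0.5]
b = -1.0

x = [2.0, 1.0]              # margin = 3.0 - 0.5 - 1.0 = 1.5
p = predict_proba(w, b, x)
label = predict(w, b, x)
print(p, label)
```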
<p><strong>Example</strong></p>
<p>The following example shows how to train a logistic regression model
with elastic net regularization. <code>elasticNetParam</code> corresponds to
$\alpha$ and <code>regParam</code> corresponds to $\lambda$.</p>
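<p>Since <code>elasticNetParam</code> is $\alpha$ and <code>regParam</code> is $\lambda$, the settings in the example that follows translate into separate strengths for the two penalty terms: $\alpha\lambda$ multiplies $\|\wv\|_1$ and $(1-\alpha)\lambda$ multiplies $\frac{1}{2}\|\wv\|_2^2$. The short sketch below computes them for <code>regParam=0.3</code> and <code>elasticNetParam=0.8</code>; the helper name <code>effective_strengths</code> is hypothetical.</p>

```python
# Split (regParam, elasticNetParam) into the strengths of the two penalty terms
#   alpha * lambda * ||w||_1  +  (1 - alpha) * lambda * (1/2) * ||w||_2^2
# where regParam = lambda and elasticNetParam = alpha. Illustrative only.

def effective_strengths(reg_param, elastic_net_param):
    l1_strength = elastic_net_param * reg_param          # multiplies ||w||_1
    l2_strength = (1.0 - elastic_net_param) * reg_param  # multiplies (1/2) * ||w||_2^2
    return l1_strength, l2_strength


l1, l2 = effective_strengths(reg_param=0.3, elastic_net_param=0.8)
print(f"L1 strength: {l1:.2f}, L2 strength: {l2:.2f}")
```

<p>With these settings most of the regularization weight sits on the $L_1$ term, which is expected to push some coefficients to exactly zero.</p>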
<div class="codetabs">
<div data-lang="scala">
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span>
<span class="c1">// Load training data</span>
<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span>
<span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span>
<span class="c1">// Fit the model</span>
<span class="k">val</span> <span class="n">lrModel</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span>
<span class="c1">// Print the coefficients and intercept for logistic regression</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}&quot;</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load training data</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="n">LogisticRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span>
<span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span>
<span class="c1">// Fit the model</span>
<span class="n">LogisticRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Print the coefficients and intercept for logistic regression</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Coefficients: &quot;</span>
<span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">&quot; Intercept: &quot;</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionWithElasticNetExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span>
<span class="c"># Load training data</span>
<span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
<span class="c"># Fit the model</span>
<span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span>
<span class="c"># Print the coefficients and intercept for logistic regression</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Coefficients: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Intercept: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/logistic_regression_with_elastic_net.py" in the Spark repo.</small></div>
</div>
</div>
<p>The <code>spark.ml</code> implementation of logistic regression also supports
extracting a summary of the model over the training set. Note that the
predictions and metrics, which are stored as <code>DataFrame</code>s in
<code>BinaryLogisticRegressionSummary</code>, are annotated <code>@transient</code> and hence
only available on the driver.</p>
<div class="codetabs">
<div data-lang="scala">
<p><a href="api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionTrainingSummary"><code>LogisticRegressionTrainingSummary</code></a>
provides a summary for a
<a href="api/scala/index.html#org.apache.spark.ml.classification.LogisticRegressionModel"><code>LogisticRegressionModel</code></a>.
Currently, only binary classification is supported and the
summary must be explicitly cast to
<a href="api/scala/index.html#org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"><code>BinaryLogisticRegressionTrainingSummary</code></a>.
This will likely change when multiclass classification is supported.</p>
<p>Continuing the earlier example:</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">BinaryLogisticRegressionSummary</span><span class="o">,</span> <span class="nc">LogisticRegression</span><span class="o">}</span>
<span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span>
<span class="c1">// example</span>
<span class="k">val</span> <span class="n">trainingSummary</span> <span class="k">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span>
<span class="c1">// Obtain the objective per iteration.</span>
<span class="k">val</span> <span class="n">objectiveHistory</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">objectiveHistory</span>
<span class="n">objectiveHistory</span><span class="o">.</span><span class="n">foreach</span><span class="o">(</span><span class="n">loss</span> <span class="k">=&gt;</span> <span class="n">println</span><span class="o">(</span><span class="n">loss</span><span class="o">))</span>
<span class="c1">// Obtain the metrics computed over the training data.</span>
<span class="c1">// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a</span>
<span class="c1">// binary classification problem.</span>
<span class="k">val</span> <span class="n">binarySummary</span> <span class="k">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">BinaryLogisticRegressionSummary</span><span class="o">]</span>
<span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span>
<span class="k">val</span> <span class="n">roc</span> <span class="k">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="n">roc</span>
<span class="n">roc</span><span class="o">.</span><span class="n">show</span><span class="o">()</span>
<span class="n">println</span><span class="o">(</span><span class="n">binarySummary</span><span class="o">.</span><span class="n">areaUnderROC</span><span class="o">)</span>
<span class="c1">// Set the model threshold to maximize F-Measure</span>
<span class="k">val</span> <span class="n">fMeasure</span> <span class="k">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="n">fMeasureByThreshold</span>
<span class="k">val</span> <span class="n">maxFMeasure</span> <span class="k">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="n">max</span><span class="o">(</span><span class="s">&quot;F-Measure&quot;</span><span class="o">)).</span><span class="n">head</span><span class="o">().</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span>
<span class="k">val</span> <span class="n">bestThreshold</span> <span class="k">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="n">where</span><span class="o">(</span><span class="n">$</span><span class="s">&quot;F-Measure&quot;</span> <span class="o">===</span> <span class="n">maxFMeasure</span><span class="o">)</span>
<span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;threshold&quot;</span><span class="o">).</span><span class="n">head</span><span class="o">().</span><span class="n">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span>
<span class="n">lrModel</span><span class="o">.</span><span class="n">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p><a href="api/java/org/apache/spark/ml/classification/LogisticRegressionTrainingSummary.html"><code>LogisticRegressionTrainingSummary</code></a>
provides a summary for a
<a href="api/java/org/apache/spark/ml/classification/LogisticRegressionModel.html"><code>LogisticRegressionModel</code></a>.
Currently, only binary classification is supported and the
summary must be explicitly cast to
<a href="api/java/org/apache/spark/ml/classification/BinaryLogisticRegressionTrainingSummary.html"><code>BinaryLogisticRegressionTrainingSummary</code></a>.
This will likely change when multiclass classification is supported.</p>
<p>Continuing the earlier example:</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.BinaryLogisticRegressionSummary</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegressionTrainingSummary</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.functions</span><span class="o">;</span>
<span class="c1">// Extract the summary from the returned LogisticRegressionModel instance trained in the earlier</span>
<span class="c1">// example</span>
<span class="n">LogisticRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span>
<span class="c1">// Obtain the objective per iteration.</span>
<span class="kt">double</span><span class="o">[]</span> <span class="n">objectiveHistory</span> <span class="o">=</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">();</span>
<span class="k">for</span> <span class="o">(</span><span class="kt">double</span> <span class="n">lossPerIteration</span> <span class="o">:</span> <span class="n">objectiveHistory</span><span class="o">)</span> <span class="o">{</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">lossPerIteration</span><span class="o">);</span>
<span class="o">}</span>
<span class="c1">// Obtain the metrics computed over the training data.</span>
<span class="c1">// We cast the summary to a BinaryLogisticRegressionSummary since the problem is a binary</span>
<span class="c1">// classification problem.</span>
<span class="n">BinaryLogisticRegressionSummary</span> <span class="n">binarySummary</span> <span class="o">=</span>
<span class="o">(</span><span class="n">BinaryLogisticRegressionSummary</span><span class="o">)</span> <span class="n">trainingSummary</span><span class="o">;</span>
<span class="c1">// Obtain the receiver-operating characteristic as a dataframe and areaUnderROC.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">roc</span> <span class="o">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="na">roc</span><span class="o">();</span>
<span class="n">roc</span><span class="o">.</span><span class="na">show</span><span class="o">();</span>
<span class="n">roc</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;FPR&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">();</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">binarySummary</span><span class="o">.</span><span class="na">areaUnderROC</span><span class="o">());</span>
<span class="c1">// Set the model threshold to the value that maximizes F-Measure.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">fMeasure</span> <span class="o">=</span> <span class="n">binarySummary</span><span class="o">.</span><span class="na">fMeasureByThreshold</span><span class="o">();</span>
<span class="kt">double</span> <span class="n">maxFMeasure</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="n">functions</span><span class="o">.</span><span class="na">max</span><span class="o">(</span><span class="s">&quot;F-Measure&quot;</span><span class="o">)).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">bestThreshold</span> <span class="o">=</span> <span class="n">fMeasure</span><span class="o">.</span><span class="na">where</span><span class="o">(</span><span class="n">fMeasure</span><span class="o">.</span><span class="na">col</span><span class="o">(</span><span class="s">&quot;F-Measure&quot;</span><span class="o">).</span><span class="na">equalTo</span><span class="o">(</span><span class="n">maxFMeasure</span><span class="o">))</span>
<span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;threshold&quot;</span><span class="o">).</span><span class="na">head</span><span class="o">().</span><span class="na">getDouble</span><span class="o">(</span><span class="mi">0</span><span class="o">);</span>
<span class="n">lrModel</span><span class="o">.</span><span class="na">setThreshold</span><span class="o">(</span><span class="n">bestThreshold</span><span class="o">);</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLogisticRegressionSummaryExample.java" in the Spark repo.</small></div>
</div>
<!--- TODO: Add python model summaries once implemented -->
<div data-lang="python">
<p>Logistic regression model summary is not yet supported in Python.</p>
</div>
</div>
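<p>The threshold-selection step in the Scala and Java summary examples keeps the threshold whose F-Measure is largest. The plain-Python sketch below illustrates that idea on a few made-up scores and labels: it computes F1 at every distinct score and returns the best pair. Treating a score at or above the threshold as a positive prediction is an assumption of this sketch, and the helpers are not Spark's implementation of <code>fMeasureByThreshold</code>.</p>

```python
# Illustrative threshold selection by maximum F1 (not Spark's implementation).

def f1_at_threshold(scores, labels, threshold):
    """F1 score when every score >= threshold is predicted positive."""
    tp = fp = fn = 0
    for s, y in zip(scores, labels):
        predicted = 1 if s >= threshold else 0
        if predicted == 1 and y == 1:
            tp += 1
        elif predicted == 1 and y == 0:
            fp += 1
        elif predicted == 0 and y == 1:
            fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def best_threshold(scores, labels):
    """Return (threshold, F1) for the distinct score that maximizes F1."""
    candidates = sorted(set(scores), reverse=True)
    f1_by_threshold = [(t, f1_at_threshold(scores, labels, t)) for t in candidates]
    return max(f1_by_threshold, key=lambda pair: pair[1])


# Made-up positive-class probabilities and true labels.
scores = [0.95, 0.85, 0.70, 0.60, 0.40, 0.30, 0.20, 0.10]
labels = [1, 1, 0, 1, 0, 1, 0, 0]

threshold, f1 = best_threshold(scores, labels)
print(threshold, f1)
```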
<h2 id="decision-tree-classifier">Decision tree classifier</h2>
<p>Decision trees are a popular family of classification and regression methods.
More information about the <code>spark.ml</code> implementation can be found below in the <a href="#decision-trees">section on decision trees</a>.</p>
<p><strong>Example</strong></p>
<p>The following examples load a dataset in LIBSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data; these index categories for the label and for categorical features, adding metadata to the <code>DataFrame</code> that the decision tree algorithm can recognize.</p>
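<p>The two indexing rules can be sketched in plain Python. <code>StringIndexer</code> orders labels by frequency, so the most frequent label gets index 0, and <code>VectorIndexer</code> declares a feature categorical when it has at most <code>maxCategories</code> distinct values. The helpers below mimic only those two rules on made-up data. Their names and the tie-breaking rule for equally frequent labels are assumptions of this sketch, and the real transformers also attach column metadata that is not modeled here.</p>

```python
from collections import Counter

# Illustrative sketch of the two indexing rules; not Spark's implementation.

def string_index(values):
    """Map each distinct value to an index, most frequent value first.

    Ties are broken by the value itself so the output is deterministic;
    this tie-breaking rule is an assumption of the sketch.
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {v: float(i) for i, v in enumerate(ordered)}


def categorical_feature_indices(rows, max_categories):
    """Indices of features with at most max_categories distinct values."""
    num_features = len(rows[0])
    categorical = []
    for j in range(num_features):
        distinct = {row[j] for row in rows}
        if len(distinct) <= max_categories:
            categorical.append(j)
    return categorical


labels = [1.0, 0.0, 1.0, 1.0, 0.0, 2.0]   # made-up labels
label_index = string_index(labels)

rows = [                                   # made-up feature vectors
    [0.0, 3.2, 1.0],
    [1.0, 7.5, 0.0],
    [2.0, 1.1, 1.0],
    [0.0, 9.9, 0.0],
    [1.0, 4.4, 1.0],
]
cat_features = categorical_feature_indices(rows, max_categories=4)
print(label_index, cat_features)
```

<p>Here feature 1 has five distinct values, more than <code>max_categories=4</code>, so it stays continuous, while features 0 and 2 are treated as categorical.</p>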
<div class="codetabs">
<div data-lang="scala">
<p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.classification.DecisionTreeClassifier">Scala API documentation</a>.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span>
<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with &gt; 4 distinct values are treated as continuous.</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a DecisionTree model.</span>
<span class="k">val</span> <span class="n">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span>
<span class="c1">// Chain indexers and tree in a Pipeline.</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span>
<span class="k">val</span> <span class="n">treeModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeClassificationModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned classification tree model:\n&quot;</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/classification/DecisionTreeClassifier.html">Java API documentation</a>.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassifier</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.DecisionTreeClassificationModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span>
<span class="o">.</span><span class="na">read</span><span class="o">()</span>
<span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span> <span class="c1">// features with &gt; 4 distinct values are treated as continuous.</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a DecisionTree model.</span>
<span class="n">DecisionTreeClassifier</span> <span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">DecisionTreeClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">);</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span>
<span class="c1">// Chain indexers and tree in a Pipeline.</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span>
<span class="n">DecisionTreeClassificationModel</span> <span class="n">treeModel</span> <span class="o">=</span>
<span class="o">(</span><span class="n">DecisionTreeClassificationModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned classification tree model:\n&quot;</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeClassificationExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.DecisionTreeClassifier">Python API documentation</a>.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">DecisionTreeClassifier</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="c"># Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Index labels, adding metadata to the label column.</span>
<span class="c"># Fit on whole dataset to include all labels in index.</span>
<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># We specify maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a DecisionTree model.</span>
<span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">)</span>
<span class="c"># Chain indexers and tree in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexers.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;accuracy&quot;</span><span class="p">)</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Test Error = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span>
<span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="c"># Printing the model shows only a short summary, not the full tree structure.</span>
<span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/decision_tree_classification_example.py" in the Spark repo.</small></div>
</div>
</div>
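<p>The evaluation step in these examples reduces to comparing each prediction with its indexed label. The test error is then one minus the accuracy. Below is a plain-Python sketch of that arithmetic, not a Spark API. The (prediction, indexedLabel) pairs are made up for illustration.</p>

```python
def accuracy(pairs):
    """Fraction of (prediction, label) pairs where the prediction matches the label."""
    if not pairs:
        raise ValueError("need at least one prediction")
    correct = sum(1 for prediction, label in pairs if prediction == label)
    return correct / len(pairs)


# Hypothetical (prediction, indexedLabel) rows; not real output of the example.
pairs = [(0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0)]
acc = accuracy(pairs)
print("Test Error = %g" % (1.0 - acc))  # Test Error = 0.25
```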
<h2 id="random-forest-classifier">Random forest classifier</h2>
<p>Random forests are a popular family of classification and regression methods.
More information about the <code>spark.ml</code> implementation can be found below in the <a href="#random-forests">section on random forests</a>.</p>
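<p>A random forest classifier combines the predictions of many trees. For one example, the sketch below normalizes the class counts in the leaf each tree reaches, sums these per-tree probability vectors, and takes the class with the largest total. This is a simplified plain-Python model of soft voting, assumed here for illustration. It is not Spark's code. The leaf counts and the helper name <code>forest_predict</code> are hypothetical.</p>

```python
def forest_predict(per_tree_class_counts):
    """Combine per-tree leaf class counts into one prediction by summed class probabilities."""
    num_classes = len(per_tree_class_counts[0])
    votes = [0.0] * num_classes
    for counts in per_tree_class_counts:
        total = sum(counts)
        if total == 0:
            continue  # skip a leaf that carries no training weight
        for k in range(num_classes):
            votes[k] += counts[k] / total
    # The class with the largest summed probability wins (first one on ties).
    prediction = max(range(num_classes), key=lambda k: votes[k])
    return float(prediction), votes


# Hypothetical leaf class counts reached by one example in a 3-tree forest.
trees = [[8, 2], [3, 7], [9, 1]]
prediction, votes = forest_predict(trees)
print(prediction)  # 0.0
```

<p>Summing probabilities lets a confident tree outweigh an uncertain one. A hard majority vote would instead count each tree equally.</p>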
<p><strong>Example</strong></p>
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data; they index categories for the label and for categorical features, adding metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p>
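<p>The <code>maxCategories</code> setting used with <code>VectorIndexer</code> works as follows. A feature with at most that many distinct values is treated as categorical, and any feature with more distinct values stays continuous. The following plain-Python sketch applies that rule to dense rows and assigns category indices to the sorted distinct values. The sorting-based index assignment, the sample rows, and the helper name are assumptions for illustration only, not Spark's implementation.</p>

```python
def fit_vector_indexer(rows, max_categories):
    """Return {feature_index: {raw_value: category_index}} for the categorical features.

    Features with more than `max_categories` distinct values are treated as
    continuous and are left out of the returned map.
    """
    num_features = len(rows[0])
    category_maps = {}
    for j in range(num_features):
        distinct = sorted({row[j] for row in rows})
        if len(distinct) > max_categories:
            continue  # too many distinct values: keep this feature continuous
        category_maps[j] = {value: float(i) for i, value in enumerate(distinct)}
    return category_maps


# Hypothetical dense rows: feature 0 has 2 distinct values, feature 1 has 5.
rows = [
    [0.0, 0.1],
    [1.0, 2.5],
    [0.0, 3.7],
    [1.0, 4.2],
    [0.0, 5.9],
]
maps = fit_vector_indexer(rows, max_categories=4)
print(maps)  # {0: {0.0: 0.0, 1.0: 1.0}}
```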
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">RandomForestClassificationModel</span><span class="o">,</span> <span class="nc">RandomForestClassifier</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setNumTrees</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span>
<span class="c1">// Chain indexers and forest in a Pipeline.</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span>
<span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestClassificationModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned classification forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/classification/RandomForestClassifier.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassificationModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassifier</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="n">RandomForestClassifier</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RandomForestClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">);</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span>
<span class="c1">// Chain indexers and forest in a Pipeline.</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span>
<span class="n">RandomForestClassificationModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">RandomForestClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned classification forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestClassifierExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier">Python API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Index labels, adding metadata to the label column.</span>
<span class="c"># Fit on whole dataset to include all labels in index.</span>
<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing).</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a RandomForest model.</span>
<span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">)</span>
<span class="c"># Chain indexers and forest in a Pipeline.</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexers.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error.</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;accuracy&quot;</span><span class="p">)</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Test Error = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span>
<span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c"># summary only</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/random_forest_classifier_example.py" in the Spark repo.</small></div>
</div>
</div>
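<p>Each example above computes the test error as <code>1.0 - accuracy</code>, where accuracy is the fraction of held-out rows whose prediction equals the indexed label. A minimal plain-Python sketch of that arithmetic follows; it is not Spark code, and the function names and the five sample rows are hypothetical.</p>

```python
from typing import Sequence


def accuracy(labels: Sequence[float], predictions: Sequence[float]) -> float:
    """Fraction of rows whose predicted index equals the true (indexed) label."""
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if not labels:
        raise ValueError("cannot score an empty test set")
    correct = sum(1 for y, p in zip(labels, predictions) if y == p)
    return correct / len(labels)


def compute_test_error(labels: Sequence[float], predictions: Sequence[float]) -> float:
    """Misclassification rate, as printed by the examples: 1.0 - accuracy."""
    return 1.0 - accuracy(labels, predictions)


# Hypothetical (indexedLabel, prediction) values for five held-out rows.
labels = [0.0, 1.0, 1.0, 0.0, 1.0]
preds = [0.0, 1.0, 0.0, 0.0, 1.0]
print("Test Error = %g" % compute_test_error(labels, preds))  # Test Error = 0.2
```

<p>Four of the five hypothetical rows are correct, so accuracy is 0.8 and the printed test error is 0.2.</p>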
<h2 id="gradient-boosted-tree-classifier">Gradient-boosted tree classifier</h2>
<p>Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees.
More information about the <code>spark.ml</code> implementation can be found in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a> below.</p>
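<p>To make "ensembles of decision trees" concrete, here is a toy sketch in plain Python (not Spark code) of gradient boosting for binary labels. Each round fits a depth-1 regression tree (a stump) on a single numeric feature to the negative gradient of the log loss, and the ensemble's log-odds is the step-size-weighted sum of all stumps. The names <code>fit_stump</code>, <code>fit_gbt</code>, and <code>ToyGBTModel</code> are hypothetical; <code>num_iters</code> and <code>step_size</code> only loosely mirror <code>GBTClassifier</code>'s <code>maxIter</code> and <code>stepSize</code> parameters. Spark's implementation grows deeper trees over feature vectors and differs in details such as loss scaling and the weighting of the first tree.</p>

```python
import math
from dataclasses import dataclass, field
from typing import List, Sequence, Tuple

# A depth-1 regression tree: (threshold, value if x <= threshold, value otherwise).
Stump = Tuple[float, float, float]


def fit_stump(xs: Sequence[float], targets: Sequence[float]) -> Stump:
    """Choose the single split on xs that minimizes squared error on targets."""
    best = None
    for t in sorted(set(xs))[:-1]:  # every split between distinct values
        left = [r for x, r in zip(xs, targets) if x <= t]
        right = [r for x, r in zip(xs, targets) if x > t]
        lv, rv = sum(left) / len(left), sum(right) / len(right)
        sse = sum((r - lv) ** 2 for r in left) + sum((r - rv) ** 2 for r in right)
        if best is None or sse < best[0]:
            best = (sse, t, lv, rv)
    if best is None:  # only one distinct x: predict the mean everywhere
        mean = sum(targets) / len(targets)
        return (xs[0], mean, mean)
    return (best[1], best[2], best[3])


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


@dataclass
class ToyGBTModel:
    step_size: float
    stumps: List[Stump] = field(default_factory=list)

    def raw_score(self, x: float) -> float:
        """Log-odds: the step-size-weighted sum of all stump outputs."""
        return sum(self.step_size * (lv if x <= t else rv) for t, lv, rv in self.stumps)

    def predict(self, x: float) -> float:
        return 1.0 if sigmoid(self.raw_score(x)) >= 0.5 else 0.0


def fit_gbt(xs: Sequence[float], ys: Sequence[float],
            num_iters: int = 10, step_size: float = 0.1) -> ToyGBTModel:
    """Boost stumps on the log loss for labels in {0.0, 1.0}."""
    model = ToyGBTModel(step_size=step_size)
    for _ in range(num_iters):
        # The negative gradient of the log loss w.r.t. the log-odds is y - p.
        residuals = [y - sigmoid(model.raw_score(x)) for x, y in zip(xs, ys)]
        model.stumps.append(fit_stump(xs, residuals))
    return model


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
model = fit_gbt(xs, ys, num_iters=10, step_size=0.1)
print([model.predict(x) for x in xs])  # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
```

<p>Each new stump corrects what the current ensemble still gets wrong, which is why boosting adds trees sequentially rather than training them independently as a random forest does.</p>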
<p><strong>Example</strong></p>
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data: <code>StringIndexer</code> indexes the label column and <code>VectorIndexer</code> indexes the categorical features. Both add metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p>
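<p>The indexing rules can be sketched in plain Python (not Spark code). <code>StringIndexer</code> orders labels by frequency so that the most frequent label gets index 0.0, <code>IndexToString</code> maps indices back through the same label list, and <code>VectorIndexer</code> treats a feature as categorical only if it has at most <code>maxCategories</code> distinct values. All function names here are hypothetical. Breaking frequency ties alphabetically and assigning category indices in sorted value order are assumptions of this sketch and may not match Spark's exact index assignment.</p>

```python
from collections import Counter
from typing import Dict, List, Sequence


def fit_label_index(labels: Sequence[str]) -> List[str]:
    """Distinct labels, most frequent first; a label's position is its index.
    Ties are broken alphabetically (an assumption made for determinism)."""
    counts = Counter(labels)
    return sorted(counts, key=lambda label: (-counts[label], label))


def index_labels(labels: Sequence[str], ordered: List[str]) -> List[float]:
    """Mimics StringIndexerModel.transform: label -> index as a double."""
    lookup = {label: float(i) for i, label in enumerate(ordered)}
    return [lookup[label] for label in labels]


def index_to_string(indices: Sequence[float], ordered: List[str]) -> List[str]:
    """Mimics IndexToString: index -> original label."""
    return [ordered[int(i)] for i in indices]


def categorical_feature_maps(
    rows: Sequence[Sequence[float]], max_categories: int
) -> Dict[int, Dict[float, int]]:
    """Columns with <= max_categories distinct values are categorical and get a
    value -> category-index map; other columns are omitted (treated as continuous)."""
    maps: Dict[int, Dict[float, int]] = {}
    for j in range(len(rows[0])):
        distinct = sorted({row[j] for row in rows})
        if len(distinct) <= max_categories:
            maps[j] = {value: i for i, value in enumerate(distinct)}
    return maps


ordered = fit_label_index(["b", "a", "b", "c", "b", "a"])
print(ordered)  # ['b', 'a', 'c']
print(index_labels(["a", "c", "b"], ordered))  # [1.0, 2.0, 0.0]
rows = [[0.0, 1.5], [1.0, 2.7], [0.0, 3.1], [1.0, 4.2], [0.0, 5.9]]
print(categorical_feature_maps(rows, max_categories=4))  # {0: {0.0: 0, 1.0: 1}}
```

<p>With <code>max_categories=4</code>, the first column (two distinct values) is treated as categorical, while the second (five distinct values) stays continuous, matching the "features with &gt; 4 distinct values are treated as continuous" comment in the examples.</p>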
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">GBTClassificationModel</span><span class="o">,</span> <span class="nc">GBTClassifier</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a GBT model.</span>
<span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span>
<span class="c1">// Chain indexers and GBT in a Pipeline.</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span>
<span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTClassificationModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned classification GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/classification/GBTClassifier.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassificationModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassifier</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span>
<span class="o">.</span><span class="na">read</span><span class="o">()</span>
<span class="o">.</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a GBT model.</span>
<span class="n">GBTClassifier</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">GBTClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span>
<span class="c1">// Chain indexers and GBT in a Pipeline.</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span>
<span class="n">GBTClassificationModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">GBTClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned classification GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeClassifierExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier">Python API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">GBTClassifier</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Index labels, adding metadata to the label column.</span>
<span class="c"># Fit on whole dataset to include all labels in index.</span>
<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a GBT model.</span>
<span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c"># Chain indexers and GBT in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexers.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;accuracy&quot;</span><span class="p">)</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Test Error = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">))</span>
<span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c"># summary only</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_classifier_example.py" in the Spark repo.</small></div>
</div>
</div>
<h2 id="multilayer-perceptron-classifier">Multilayer perceptron classifier</h2>
<p>Multilayer perceptron classifier (MLPC) is a classifier based on the <a href="https://en.wikipedia.org/wiki/Feedforward_neural_network">feedforward artificial neural network</a>.
MLPC consists of multiple layers of nodes.
Each layer is fully connected to the next layer in the network. Nodes in the input layer represent the input data. All other nodes map inputs to outputs
by computing a linear combination of the inputs with the node&#8217;s weights <code>$\wv$</code> and bias <code>$\bv$</code> and then applying an activation function.
It can be written in matrix form for MLPC with <code>$K+1$</code> layers as follows:
<code>\[
\mathrm{y}(\x) = \mathrm{f_K}(\ldots\mathrm{f_2}(\wv_2^T\mathrm{f_1}(\wv_1^T \x+\bv_1)+\bv_2)\ldots+\bv_K)
\]</code>
Nodes in the intermediate layers use the sigmoid (logistic) function:
<code>\[
\mathrm{f}(z_i) = \frac{1}{1 + e^{-z_i}}
\]</code>
Nodes in the output layer use the softmax function:
<code>\[
\mathrm{f}(z_i) = \frac{e^{z_i}}{\sum_{k=1}^N e^{z_k}}
\]</code>
The number of nodes <code>$N$</code> in the output layer corresponds to the number of classes.</p>
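<p>To illustrate the formula, the following plain-Python sketch runs one forward pass through a tiny network: an affine map per layer, the sigmoid on intermediate layers and the softmax on the output layer. It is not Spark&#8217;s implementation; the helper names, the weight layout (one list of incoming weights per node) and the example weights are all assumptions chosen for illustration.</p>

```python
import math


def sigmoid(z):
    """Logistic activation used by the intermediate layers."""
    return 1.0 / (1.0 + math.exp(-z))


def softmax(zs):
    """Softmax activation used by the output layer.

    Subtracting the maximum before exponentiating avoids overflow and
    leaves the result unchanged.
    """
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]


def affine(weights, bias, x):
    """Compute W^T x + b, where weights[j] holds the incoming weights of node j."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weights, bias)]


def mlp_forward(params, x):
    """Forward pass; params is a list of (weights, bias) pairs, one per non-input layer."""
    a = x
    last = len(params) - 1
    for k, (weights, bias) in enumerate(params):
        z = affine(weights, bias, a)
        # Softmax on the output layer, sigmoid on every intermediate layer.
        a = softmax(z) if k == last else [sigmoid(zk) for zk in z]
    return a


# Hypothetical weights for a 2-2-2 network (not a trained model).
hidden = ([[0.5, -0.5], [0.3, 0.8]], [0.0, 0.1])
output = ([[1.0, -1.0], [-1.0, 1.0]], [0.0, 0.0])
probs = mlp_forward([hidden, output], [1.0, 2.0])
print(probs)  # two class probabilities that sum to 1
```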
<p>MLPC employs backpropagation to learn the model. We use the logistic loss function for optimization and L-BFGS as the optimization routine.</p>
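<p>The starting point of backpropagation can be sketched for a single example. This sketch reads the logistic loss of a softmax output layer as the negative log-probability of the true class (an assumption, not a statement about Spark&#8217;s internals) and shows that its gradient with respect to each output-layer input is <code>$p_i - y_i$</code>, where <code>$y$</code> is the one-hot label. An optimizer such as L-BFGS would use gradients like this one, propagated back through all layers, to update the weights; that part is not shown. The helper names and input values are hypothetical.</p>

```python
import math


def softmax(zs):
    """Softmax with the maximum subtracted for numerical stability."""
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]


def logistic_loss(zs, label):
    """Negative log-probability that softmax(zs) assigns to the true class index."""
    return -math.log(softmax(zs)[label])


def output_error(zs, label):
    """Gradient of logistic_loss with respect to each output-layer input z_i.

    For softmax with this loss the gradient is p_i - y_i, where y is the
    one-hot encoding of the label; backpropagation starts from this vector.
    """
    probs = softmax(zs)
    return [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]


zs = [2.0, 0.5, -1.0]  # hypothetical output-layer inputs for one example
print(logistic_loss(zs, 0))
print(output_error(zs, 0))  # entries sum to (almost exactly) 0
```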
<p><strong>Example</strong></p>
<div class="codetabs">
<div data-lang="scala">
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Split the data into train and test</span>
<span class="k">val</span> <span class="n">splits</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">),</span> <span class="n">seed</span> <span class="k">=</span> <span class="mi">1234L</span><span class="o">)</span>
<span class="k">val</span> <span class="n">train</span> <span class="k">=</span> <span class="n">splits</span><span class="o">(</span><span class="mi">0</span><span class="o">)</span>
<span class="k">val</span> <span class="n">test</span> <span class="k">=</span> <span class="n">splits</span><span class="o">(</span><span class="mi">1</span><span class="o">)</span>
<span class="c1">// specify layers for the neural network:</span>
<span class="c1">// input layer of size 4 (features), two intermediate of size 5 and 4</span>
<span class="c1">// and output of size 3 (classes)</span>
<span class="k">val</span> <span class="n">layers</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">[</span><span class="kt">Int</span><span class="o">](</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">)</span>
<span class="c1">// create the trainer and set its parameters</span>
<span class="k">val</span> <span class="n">trainer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MultilayerPerceptronClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span>
<span class="o">.</span><span class="n">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span>
<span class="o">.</span><span class="n">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">)</span>
<span class="c1">// train the model</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span>
<span class="c1">// compute accuracy on the test set</span>
<span class="k">val</span> <span class="n">result</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span>
<span class="k">val</span> <span class="n">predictionAndLabels</span> <span class="k">=</span> <span class="n">result</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Accuracy: &quot;</span> <span class="o">+</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">))</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.MultilayerPerceptronClassifier</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="c1">// Load training data</span>
<span class="n">String</span> <span class="n">path</span> <span class="o">=</span> <span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="o">;</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">dataFrame</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="n">path</span><span class="o">);</span>
<span class="c1">// Split the data into train and test</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">1234L</span><span class="o">);</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// specify layers for the neural network:</span>
<span class="c1">// input layer of size 4 (features), two intermediate of size 5 and 4</span>
<span class="c1">// and output of size 3 (classes)</span>
<span class="kt">int</span><span class="o">[]</span> <span class="n">layers</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">int</span><span class="o">[]</span> <span class="o">{</span><span class="mi">4</span><span class="o">,</span> <span class="mi">5</span><span class="o">,</span> <span class="mi">4</span><span class="o">,</span> <span class="mi">3</span><span class="o">};</span>
<span class="c1">// create the trainer and set its parameters</span>
<span class="n">MultilayerPerceptronClassifier</span> <span class="n">trainer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MultilayerPerceptronClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLayers</span><span class="o">(</span><span class="n">layers</span><span class="o">)</span>
<span class="o">.</span><span class="na">setBlockSize</span><span class="o">(</span><span class="mi">128</span><span class="o">)</span>
<span class="o">.</span><span class="na">setSeed</span><span class="o">(</span><span class="mi">1234L</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">100</span><span class="o">);</span>
<span class="c1">// train the model</span>
<span class="n">MultilayerPerceptronClassificationModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span>
<span class="c1">// compute accuracy on the test set</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">);</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;accuracy&quot;</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Accuracy = &quot;</span> <span class="o">+</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">));</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaMultilayerPerceptronClassifierExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">MultilayerPerceptronClassifier</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="c"># Load training data</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span>\
<span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Split the data into train and test</span>
<span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="mi">1234</span><span class="p">)</span>
<span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="c"># specify layers for the neural network:</span>
<span class="c"># input layer of size 4 (features), two intermediate of size 5 and 4</span>
<span class="c"># and output of size 3 (classes)</span>
<span class="n">layers</span> <span class="o">=</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">]</span>
<span class="c"># create the trainer and set its parameters</span>
<span class="n">trainer</span> <span class="o">=</span> <span class="n">MultilayerPerceptronClassifier</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">layers</span><span class="o">=</span><span class="n">layers</span><span class="p">,</span> <span class="n">blockSize</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">1234</span><span class="p">)</span>
<span class="c"># train the model</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">trainer</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span>
<span class="c"># compute precision on the test set</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span>
<span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;label&quot;</span><span class="p">)</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s">&quot;precision&quot;</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Precision: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictionAndLabels</span><span class="p">)))</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/multilayer_perceptron_classification.py" in the Spark repo.</small></div>
</div>
</div>
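<p>As a sanity check on the <code>layers</code> setting, the number of weights the trainer has to learn can be worked out by hand. The sketch below is plain Python rather than Spark code, and it assumes that each layer is fully connected to the next with one bias term per output neuron; the helper name <code>mlp_weight_count</code> is hypothetical.</p>

```python
def mlp_weight_count(layers):
    """Count the parameters of a fully connected feed-forward network.

    Assumes each transition from layers[i] to layers[i + 1] has a
    layers[i] x layers[i + 1] weight matrix plus one bias per output unit.
    """
    return sum((n_in + 1) * n_out for n_in, n_out in zip(layers, layers[1:]))


# Same topology as the example: 4 inputs, hidden layers of 5 and 4, 3 outputs.
print(mlp_weight_count([4, 5, 4, 3]))  # 64 = (4+1)*5 + (5+1)*4 + (4+1)*3
```

<p>Under that assumption, every extra neuron in a hidden layer adds weights on both its incoming and its outgoing side.</p>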
<h2 id="one-vs-rest-classifier-aka-one-vs-all">One-vs-Rest classifier (a.k.a. One-vs-All)</h2>
<p><a href="http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest">OneVsRest</a> is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as &#8220;One-vs-All.&#8221;</p>
<p><code>OneVsRest</code> is implemented as an <code>Estimator</code>. It takes an instance of <code>Classifier</code> as its base classifier and creates a binary classification problem for each of the k classes: the classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes.</p>
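<p>The relabeling step can be illustrated with a small plain-Python sketch. It is not Spark's implementation; it assumes, as spark.ml does, that class labels are the doubles 0.0 to k-1, and the function name <code>one_vs_rest_labels</code> is hypothetical.</p>

```python
def one_vs_rest_labels(labels, num_classes):
    """Build one binary label vector per class.

    For class i, a row gets 1.0 if its label equals i and 0.0 otherwise,
    so binary classifier i learns to separate class i from all others.
    """
    return [[1.0 if y == i else 0.0 for y in labels] for i in range(num_classes)]


labels = [0.0, 2.0, 1.0, 2.0]
for i, binary in enumerate(one_vs_rest_labels(labels, 3)):
    print(i, binary)
# 0 [1.0, 0.0, 0.0, 0.0]
# 1 [0.0, 0.0, 1.0, 0.0]
# 2 [0.0, 1.0, 0.0, 1.0]
```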
<p>Predictions are made by evaluating each binary classifier and outputting the index of the most confident classifier as the label.</p>
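<p>A minimal sketch of this prediction rule, again in plain Python and not taken from Spark's source. It assumes each binary classifier has already produced a confidence score for its own class (for example a raw margin or a probability); the function name <code>one_vs_rest_predict</code> is hypothetical.</p>

```python
def one_vs_rest_predict(confidences):
    """Return the index of the most confident binary classifier.

    confidences[i] is the score that classifier i assigns to
    "this row belongs to class i". Ties go to the lowest index.
    """
    best_index = 0
    for i, score in enumerate(confidences):
        if score > confidences[best_index]:
            best_index = i
    return float(best_index)  # spark.ml represents labels as doubles


print(one_vs_rest_predict([-1.2, 0.4, 0.3]))  # 1.0
```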
<p><strong>Example</strong></p>
<p>The example below demonstrates how to load the
<a href="http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale">Iris dataset</a>, parse it as a DataFrame, and perform multiclass classification using <code>OneVsRest</code>. The test error is calculated to measure the algorithm&#8217;s accuracy.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.OneVsRest">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">LogisticRegression</span><span class="o">,</span> <span class="nc">OneVsRest</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span>
<span class="c1">// load data file.</span>
<span class="k">val</span> <span class="n">inputData</span><span class="k">:</span> <span class="kt">DataFrame</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// generate the train/test split.</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">train</span><span class="o">,</span> <span class="n">test</span><span class="o">)</span> <span class="k">=</span> <span class="n">inputData</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.8</span><span class="o">,</span> <span class="mf">0.2</span><span class="o">))</span>
<span class="c1">// instantiate the base classifier</span>
<span class="k">val</span> <span class="n">classifier</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="n">setTol</span><span class="o">(</span><span class="mi">1</span><span class="n">E</span><span class="o">-</span><span class="mi">6</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFitIntercept</span><span class="o">(</span><span class="kc">true</span><span class="o">)</span>
<span class="c1">// instantiate the One Vs Rest Classifier.</span>
<span class="k">val</span> <span class="n">ovr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">OneVsRest</span><span class="o">().</span><span class="n">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">)</span>
<span class="c1">// train the multiclass model.</span>
<span class="k">val</span> <span class="n">ovrModel</span> <span class="k">=</span> <span class="n">ovr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span>
<span class="c1">// score the model on test data.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span>
<span class="c1">// obtain evaluator.</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">)</span>
<span class="c1">// compute the classification error on test data.</span>
<span class="k">val</span> <span class="n">precision</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;Test Error : ${1 - precision}&quot;</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/classification/OneVsRest.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRest</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRestModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="c1">// load data file.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">inputData</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// generate the train/test split.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">tmp</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.8</span><span class="o">,</span> <span class="mf">0.2</span><span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">train</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">test</span> <span class="o">=</span> <span class="n">tmp</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// configure the base classifier.</span>
<span class="n">LogisticRegression</span> <span class="n">classifier</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="na">setTol</span><span class="o">(</span><span class="mi">1</span><span class="n">E</span><span class="o">-</span><span class="mi">6</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFitIntercept</span><span class="o">(</span><span class="kc">true</span><span class="o">);</span>
<span class="c1">// instantiate the One Vs Rest Classifier.</span>
<span class="n">OneVsRest</span> <span class="n">ovr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">OneVsRest</span><span class="o">().</span><span class="na">setClassifier</span><span class="o">(</span><span class="n">classifier</span><span class="o">);</span>
<span class="c1">// train the multiclass model.</span>
<span class="n">OneVsRestModel</span> <span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span>
<span class="c1">// score the model on test data.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span>
<span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">);</span>
<span class="c1">// obtain evaluator.</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">);</span>
<span class="c1">// compute the classification error on test data.</span>
<span class="kt">double</span> <span class="n">precision</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Test Error : &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">precision</span><span class="o">));</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaOneVsRestExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.OneVsRest">Python API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">LogisticRegression</span><span class="p">,</span> <span class="n">OneVsRest</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="c"># load data file.</span>
<span class="n">inputData</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span> \
<span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="p">)</span>
<span class="c"># generate the train/test split.</span>
<span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">test</span><span class="p">)</span> <span class="o">=</span> <span class="n">inputData</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.8</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">])</span>
<span class="c"># instantiate the base classifier.</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">LogisticRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">tol</span><span class="o">=</span><span class="mf">1E-6</span><span class="p">,</span> <span class="n">fitIntercept</span><span class="o">=</span><span class="bp">True</span><span class="p">)</span>
<span class="c"># instantiate the One Vs Rest Classifier.</span>
<span class="n">ovr</span> <span class="o">=</span> <span class="n">OneVsRest</span><span class="p">(</span><span class="n">classifier</span><span class="o">=</span><span class="n">lr</span><span class="p">)</span>
<span class="c"># train the multiclass model.</span>
<span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span>
<span class="c"># score the model on test data.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span>
<span class="c"># obtain evaluator.</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s">&quot;precision&quot;</span><span class="p">)</span>
<span class="c"># compute the classification error on test data.</span>
<span class="n">precision</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Test Error : &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">precision</span><span class="p">))</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/one_vs_rest_example.py" in the Spark repo.</small></div>
</div>
</div>
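<p>The examples above report the test error as 1 minus the evaluator&#8217;s <code>"precision"</code> metric. That metric is the overall (micro-averaged) precision, and because every test row receives exactly one predicted label, it equals the fraction of rows predicted correctly. The plain-Python sketch below computes the same quantity from hypothetical (prediction, label) pairs; the function name <code>accuracy</code> is illustrative only.</p>

```python
def accuracy(prediction_and_labels):
    """Fraction of (prediction, label) pairs where the prediction is correct."""
    pairs = list(prediction_and_labels)
    if not pairs:
        raise ValueError("need at least one (prediction, label) pair")
    correct = sum(1 for prediction, label in pairs if prediction == label)
    return correct / len(pairs)


pairs = [(0.0, 0.0), (1.0, 2.0), (2.0, 2.0), (1.0, 1.0)]
acc = accuracy(pairs)
print("Test Error : " + str(1 - acc))  # Test Error : 0.25
```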
<h2 id="naive-bayes">Naive Bayes</h2>
<p><a href="http://en.wikipedia.org/wiki/Naive_Bayes_classifier">Naive Bayes classifiers</a> are a family of simple
probabilistic classifiers based on applying Bayes&#8217; theorem with strong (naive) independence
assumptions between the features. The spark.ml implementation currently supports both <a href="http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html">multinomial
naive Bayes</a>
and <a href="http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html">Bernoulli naive Bayes</a>.
More information can be found in the section on <a href="mllib-naive-bayes.html#naive-bayes-sparkmllib">Naive Bayes in MLlib</a>.</p>
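<p>For intuition, the multinomial model can be written out in a few lines of plain Python. This is a textbook sketch rather than Spark&#8217;s code: it estimates class priors and per-class feature probabilities with additive smoothing (the role played by the <code>smoothing</code> parameter in the Python example), then predicts the class with the largest log-posterior. It covers only the multinomial model, assumes non-negative feature counts, and the function names are hypothetical.</p>

```python
import math


def train_multinomial_nb(rows, labels, num_classes, smoothing=1.0):
    """Fit smoothed class log-priors and per-class feature log-probabilities."""
    num_features = len(rows[0])
    log_priors, log_theta = [], []
    for c in range(num_classes):
        class_rows = [x for x, y in zip(rows, labels) if y == c]
        # Smoothed prior: (n_c + lambda) / (n + k * lambda)
        log_priors.append(math.log((len(class_rows) + smoothing) /
                                   (len(rows) + num_classes * smoothing)))
        # Smoothed feature probabilities: (count_cj + lambda) / (total_c + d * lambda)
        counts = [sum(x[j] for x in class_rows) for j in range(num_features)]
        total = sum(counts)
        log_theta.append([math.log((n + smoothing) / (total + num_features * smoothing))
                          for n in counts])
    return log_priors, log_theta


def predict_multinomial_nb(x, log_priors, log_theta):
    """Return the class with the largest log-posterior (up to a constant)."""
    scores = [prior + sum(xj * lt for xj, lt in zip(x, theta))
              for prior, theta in zip(log_priors, log_theta)]
    return max(range(len(scores)), key=scores.__getitem__)


rows = [[2.0, 0.0, 1.0], [3.0, 0.0, 0.0], [0.0, 2.0, 1.0], [0.0, 3.0, 0.0]]
labels = [0, 0, 1, 1]
priors, theta = train_multinomial_nb(rows, labels, num_classes=2)
print(predict_multinomial_nb([1.0, 0.0, 0.0], priors, theta))  # 0
```

<p>Working in log space turns the product of per-feature probabilities into a sum, which avoids numerical underflow when there are many features.</p>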
<p><strong>Example</strong></p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.NaiveBayes">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayes</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a NaiveBayes model.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">NaiveBayes</span><span class="o">()</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions on the test set and display them.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">show</span><span class="o">()</span>
<span class="c1">// Select (prediction, true label) and compute precision on the test set</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">precision</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Precision: &quot;</span> <span class="o">+</span> <span class="n">precision</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/classification/NaiveBayes.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayes</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.NaiveBayesModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load training data</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">dataFrame</span> <span class="o">=</span>
<span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// Split the data into train and test</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.6</span><span class="o">,</span> <span class="mf">0.4</span><span class="o">},</span> <span class="mi">1234L</span><span class="o">);</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// create the trainer with default parameters</span>
<span class="n">NaiveBayes</span> <span class="n">nb</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">NaiveBayes</span><span class="o">();</span>
<span class="c1">// train the model</span>
<span class="n">NaiveBayesModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">nb</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">);</span>
<span class="c1">// compute precision on the test set</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">);</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">);</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Precision = &quot;</span> <span class="o">+</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictionAndLabels</span><span class="o">));</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaNaiveBayesExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.NaiveBayes">Python API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">NaiveBayes</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="c"># Load training data</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span> \
<span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Split the data into train and test</span>
<span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">],</span> <span class="mi">1234</span><span class="p">)</span>
<span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="c"># create the trainer and set its parameters</span>
<span class="n">nb</span> <span class="o">=</span> <span class="n">NaiveBayes</span><span class="p">(</span><span class="n">smoothing</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">modelType</span><span class="o">=</span><span class="s">&quot;multinomial&quot;</span><span class="p">)</span>
<span class="c"># train the model</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">nb</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">)</span>
<span class="c"># compute precision on the test set</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">test</span><span class="p">)</span>
<span class="n">predictionAndLabels</span> <span class="o">=</span> <span class="n">result</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;label&quot;</span><span class="p">)</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span><span class="n">metricName</span><span class="o">=</span><span class="s">&quot;precision&quot;</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Precision: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictionAndLabels</span><span class="p">)))</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/naive_bayes_example.py" in the Spark repo.</small></div>
</div>
</div>
<h1 id="regression">Regression</h1>
<h2 id="linear-regression">Linear regression</h2>
<p>The interface for working with linear regression models and model
summaries is similar to the logistic regression case.</p>
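<p>The example below sets <code>regParam</code> and <code>elasticNetParam</code>, which we take to be the &#955; and &#945; of the elastic net penalty &#945;(&#955;&#8214;w&#8214;<sub>1</sub>) + (1 &#8722; &#945;)(&#955;/2 &#8214;w&#8214;<sub>2</sub><sup>2</sup>) described in the spark.ml linear methods documentation, and it prints the RMSE and r2 of the training summary. The plain-Python sketch below spells out these quantities using their textbook definitions; it is an illustration, not Spark&#8217;s implementation, and the function names are hypothetical.</p>

```python
import math


def elastic_net_penalty(coefficients, reg_param, elastic_net_param):
    """alpha * (lambda * ||w||_1) + (1 - alpha) * (lambda / 2 * ||w||_2^2)."""
    l1 = sum(abs(w) for w in coefficients)
    l2_squared = sum(w * w for w in coefficients)
    return (elastic_net_param * reg_param * l1
            + (1.0 - elastic_net_param) * reg_param / 2.0 * l2_squared)


def rmse(labels, predictions):
    """Root mean squared error of the residuals (label - prediction)."""
    n = len(labels)
    return math.sqrt(sum((y - p) ** 2 for y, p in zip(labels, predictions)) / n)


def r2(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(labels) / len(labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    ss_tot = sum((y - mean) ** 2 for y in labels)
    return 1.0 - ss_res / ss_tot


# Same regularization settings as the example: regParam=0.3, elasticNetParam=0.8.
print(elastic_net_penalty([1.0, -2.0], reg_param=0.3, elastic_net_param=0.8))  # approximately 0.87
print(rmse([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))  # approximately 0.577
print(r2([1.0, 2.0, 3.0], [1.0, 2.0, 4.0]))  # 0.5
```

<p>With &#945; = 1 the penalty reduces to pure L1 (lasso), and with &#945; = 0 it reduces to pure L2 (ridge).</p>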
<p><strong>Example</strong></p>
<p>The following
example demonstrates training an elastic net regularized linear
regression model and extracting model summary statistics.</p>
<div class="codetabs">
<div data-lang="scala">
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span>
<span class="c1">// Load training data</span>
<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_linear_regression_data.txt&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">lr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">LinearRegression</span><span class="o">()</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="n">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span>
<span class="o">.</span><span class="n">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">)</span>
<span class="c1">// Fit the model</span>
<span class="k">val</span> <span class="n">lrModel</span> <span class="k">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span>
<span class="c1">// Print the coefficients and intercept for linear regression</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}&quot;</span><span class="o">)</span>
<span class="c1">// Summarize the model over the training set and print out some metrics</span>
<span class="k">val</span> <span class="n">trainingSummary</span> <span class="k">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="n">summary</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;numIterations: ${trainingSummary.totalIterations}&quot;</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;objectiveHistory: ${trainingSummary.objectiveHistory.toList}&quot;</span><span class="o">)</span>
<span class="n">trainingSummary</span><span class="o">.</span><span class="n">residuals</span><span class="o">.</span><span class="n">show</span><span class="o">()</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;RMSE: ${trainingSummary.rootMeanSquaredError}&quot;</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;r2: ${trainingSummary.r2}&quot;</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.LinearRegressionTrainingSummary</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load training data.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_linear_regression_data.txt&quot;</span><span class="o">);</span>
<span class="n">LinearRegression</span> <span class="n">lr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">LinearRegression</span><span class="o">()</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="o">.</span><span class="na">setRegParam</span><span class="o">(</span><span class="mf">0.3</span><span class="o">)</span>
<span class="o">.</span><span class="na">setElasticNetParam</span><span class="o">(</span><span class="mf">0.8</span><span class="o">);</span>
<span class="c1">// Fit the model.</span>
<span class="n">LinearRegressionModel</span> <span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Print the coefficients and intercept for linear regression.</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Coefficients: &quot;</span>
<span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">&quot; Intercept: &quot;</span> <span class="o">+</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">intercept</span><span class="o">());</span>
<span class="c1">// Summarize the model over the training set and print out some metrics.</span>
<span class="n">LinearRegressionTrainingSummary</span> <span class="n">trainingSummary</span> <span class="o">=</span> <span class="n">lrModel</span><span class="o">.</span><span class="na">summary</span><span class="o">();</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;numIterations: &quot;</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">totalIterations</span><span class="o">());</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;objectiveHistory: &quot;</span> <span class="o">+</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="n">trainingSummary</span><span class="o">.</span><span class="na">objectiveHistory</span><span class="o">()));</span>
<span class="n">trainingSummary</span><span class="o">.</span><span class="na">residuals</span><span class="o">().</span><span class="na">show</span><span class="o">();</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;RMSE: &quot;</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">rootMeanSquaredError</span><span class="o">());</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;r2: &quot;</span> <span class="o">+</span> <span class="n">trainingSummary</span><span class="o">.</span><span class="na">r2</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaLinearRegressionWithElasticNetExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<!--- TODO: Add python model summaries once implemented -->
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">LinearRegression</span>
<span class="c"># Load training data</span>
<span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span>\
<span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_linear_regression_data.txt&quot;</span><span class="p">)</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">LinearRegression</span><span class="p">(</span><span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">regParam</span><span class="o">=</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">elasticNetParam</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span>
<span class="c"># Fit the model</span>
<span class="n">lrModel</span> <span class="o">=</span> <span class="n">lr</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span>
<span class="c"># Print the coefficients and intercept for linear regression</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Coefficients: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Intercept: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">lrModel</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/linear_regression_with_elastic_net.py" in the Spark repo.</small></div>
</div>
</div>
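<p>The examples above set <code>regParam</code> to 0.3 and <code>elasticNetParam</code> to 0.8. This assumes the elastic net penalty used by <code>spark.ml</code> linear methods, <code>regParam * (elasticNetParam * ||w||_1 + (1 - elasticNetParam) / 2 * ||w||_2^2)</code>, and ignores details such as feature standardization. Under that assumption, the plain-Python sketch below shows how the two parameters weight the L1 and L2 terms for a made-up coefficient vector. The helper <code>elastic_net_penalty</code> is hypothetical and is not part of the Spark API.</p>

```python
def elastic_net_penalty(weights, reg_param, elastic_net_param):
    """Hypothetical helper (not a Spark API): the elastic net term
    reg_param * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2),
    where alpha = elastic_net_param."""
    if not 0.0 <= elastic_net_param <= 1.0:
        raise ValueError("elastic_net_param must lie in [0, 1]")
    l1_norm = sum(abs(w) for w in weights)      # ||w||_1
    l2_norm_sq = sum(w * w for w in weights)    # ||w||_2^2
    alpha = elastic_net_param
    return reg_param * (alpha * l1_norm + (1.0 - alpha) / 2.0 * l2_norm_sq)


coefficients = [1.0, -2.0]  # made-up coefficients for illustration

# The settings used in the examples above: regParam = 0.3, elasticNetParam = 0.8.
print(elastic_net_penalty(coefficients, 0.3, 0.8))
# The two extremes: elasticNetParam = 0 is a pure L2 (ridge) penalty,
# elasticNetParam = 1 is a pure L1 (lasso) penalty.
print(elastic_net_penalty(coefficients, 0.3, 0.0))
print(elastic_net_penalty(coefficients, 0.3, 1.0))
```

<p>With <code>elasticNetParam</code> set to 0.8, most of the penalty weight falls on the L1 term, which tends to push some coefficients to exactly zero.</p>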
<h2 id="decision-tree-regression">Decision tree regression</h2>
<p>Decision trees are a popular family of classification and regression methods.
More information about the <code>spark.ml</code> implementation can be found further down, in the <a href="#decision-trees">section on decision trees</a>.</p>
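<p>For regression, <code>spark.ml</code> decision trees measure node impurity by variance. As a rough illustration of what that means, the plain-Python sketch below picks a threshold on a single continuous feature by maximizing the weighted variance reduction. The helper names and data are hypothetical, and Spark's actual implementation (feature binning, distributed aggregation of split statistics) is considerably more involved.</p>

```python
def variance(values):
    """Population variance of a list of numbers (0.0 for an empty list)."""
    if not values:
        return 0.0
    mean = sum(values) / len(values)
    return sum((v - mean) ** 2 for v in values) / len(values)


def best_split(xs, ys):
    """Return (threshold, gain) for the split "x <= threshold" that gives the
    largest weighted variance reduction, or (None, 0.0) if no split helps."""
    n = len(ys)
    parent_impurity = variance(ys)
    best = (None, 0.0)
    # Candidate thresholds: every distinct value except the largest, so that
    # both children are non-empty.
    for threshold in sorted(set(xs))[:-1]:
        left = [y for x, y in zip(xs, ys) if x <= threshold]
        right = [y for x, y in zip(xs, ys) if x > threshold]
        child_impurity = (len(left) * variance(left) + len(right) * variance(right)) / n
        gain = parent_impurity - child_impurity
        if gain > best[1]:
            best = (threshold, gain)
    return best


# Made-up data: the label jumps between feature values 3.0 and 10.0.
xs = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
ys = [1.0, 1.2, 0.9, 5.0, 5.1, 4.8]
print(best_split(xs, ys))  # the chosen threshold is 3.0
```

<p>A full tree applies this search recursively to each child node, over all features, until a stopping rule such as a maximum depth is reached.</p>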
<p><strong>Example</strong></p>
<p>The following examples load a dataset in LIBSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use a feature transformer to index categorical features, adding metadata to the <code>DataFrame</code> that the decision tree algorithm can recognize.</p>
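<p>In these examples, <code>VectorIndexer</code> with <code>maxCategories</code> set to 4 treats a feature as categorical when it has at most 4 distinct values and re-encodes those values as 0-based category indices, leaving the other features unchanged. The plain-Python sketch below mimics that rule on a tiny in-memory dataset. The function name is hypothetical, and the assumption that indices follow the sorted order of the distinct values is a simplification; Spark's exact index assignment may differ in details.</p>

```python
def index_features(rows, max_categories):
    """Hypothetical sketch of the VectorIndexer rule: a feature with at most
    max_categories distinct values is treated as categorical and its values
    are replaced by 0-based indices (assumed: sorted order of the values).
    Returns (category_maps, indexed_rows)."""
    num_features = len(rows[0])
    category_maps = {}
    for j in range(num_features):
        distinct = sorted({row[j] for row in rows})
        if len(distinct) <= max_categories:
            category_maps[j] = {value: float(i) for i, value in enumerate(distinct)}
    indexed = [
        [category_maps[j][v] if j in category_maps else v for j, v in enumerate(row)]
        for row in rows
    ]
    return category_maps, indexed


rows = [
    [0.0, 1.5, 10.0],
    [1.0, 2.5, 20.0],
    [0.0, 3.5, 10.0],
    [1.0, 4.5, 30.0],
    [0.0, 5.5, 20.0],
]
maps, indexed = index_features(rows, max_categories=4)
print(maps)     # feature 1 has 5 distinct values, so it stays continuous
print(indexed)
```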
<div class="codetabs">
<div data-lang="scala">
<p>More details on parameters can be found in the <a href="api/scala/index.html#org.apache.spark.ml.regression.DecisionTreeRegressor">Scala API documentation</a>.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span>
<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Here, we treat features with &gt; 4 distinct values as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a DecisionTree model.</span>
<span class="k">val</span> <span class="n">dt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">DecisionTreeRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="c1">// Chain indexer and tree in a Pipeline.</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span>
<span class="k">val</span> <span class="n">treeModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">DecisionTreeRegressionModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned regression tree model:\n&quot;</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>More details on parameters can be found in the <a href="api/java/org/apache/spark/ml/regression/DecisionTreeRegressor.html">Java API documentation</a>.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.DecisionTreeRegressor</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a DecisionTree model.</span>
<span class="n">DecisionTreeRegressor</span> <span class="n">dt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">DecisionTreeRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">);</span>
<span class="c1">// Chain indexer and tree in a Pipeline.</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">dt</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span>
<span class="n">DecisionTreeRegressionModel</span> <span class="n">treeModel</span> <span class="o">=</span>
<span class="o">(</span><span class="n">DecisionTreeRegressionModel</span><span class="o">)</span> <span class="o">(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned regression tree model:\n&quot;</span> <span class="o">+</span> <span class="n">treeModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaDecisionTreeRegressionExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>More details on parameters can be found in the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.DecisionTreeRegressor">Python API documentation</a>.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">DecisionTreeRegressor</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span>
<span class="c"># Load the data stored in LIBSVM format as a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># We specify maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a DecisionTree model.</span>
<span class="n">dt</span> <span class="o">=</span> <span class="n">DecisionTreeRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">)</span>
<span class="c"># Chain indexer and tree in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">dt</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexer.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;rmse&quot;</span><span class="p">)</span>
<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span>
<span class="n">treeModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="c"># summary only</span>
<span class="k">print</span><span class="p">(</span><span class="n">treeModel</span><span class="p">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/decision_tree_regression_example.py" in the Spark repo.</small></div>
</div>
</div>
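<p>A <code>RegressionEvaluator</code> with <code>metricName</code> set to <code>"rmse"</code> reports the root mean squared error, <code>sqrt(mean((prediction - label)^2))</code>, over the rows it is given. The plain-Python sketch below computes the same quantity for a few made-up values, which can help when sanity-checking the printed test error. The function name is illustrative and is not part of Spark.</p>

```python
import math


def rmse(predictions, labels):
    """Root mean squared error between paired predictions and true labels."""
    if len(predictions) != len(labels) or not labels:
        raise ValueError("need two non-empty sequences of equal length")
    squared_errors = [(p - y) ** 2 for p, y in zip(predictions, labels)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Made-up (prediction, label) pairs.
print(rmse([2.5, 0.0, 2.0, 8.0], [3.0, -0.5, 2.0, 7.0]))
```

<p>Because errors are squared before averaging, RMSE is expressed in the units of the label and weights large errors more heavily than small ones.</p>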
<h2 id="random-forest-regression">Random forest regression</h2>
<p>Random forests are a popular family of classification and regression methods.
More information about the <code>spark.ml</code> implementation can be found further down, in the <a href="#random-forests">section on random forests</a>.</p>
<p><strong>Example</strong></p>
<p>The following examples load a dataset in LIBSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use a feature transformer to index categorical features, adding metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p>
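<p>For regression, a random forest predicts the average of its individual trees' predictions. The plain-Python sketch below illustrates that idea under simplifying assumptions: each "tree" is a depth-1 stump on a single feature, each stump is trained on a bootstrap sample drawn with replacement, and the feature subsampling that real random forests also perform is omitted. All names and data are hypothetical; this is not Spark's implementation.</p>

```python
import random
from statistics import mean


def fit_stump(xs, ys):
    """Fit a depth-1 regression tree on one feature by minimizing squared error.
    Returns a function mapping x to a prediction."""
    best = None
    for t in sorted(set(xs))[:-1]:  # keep both sides non-empty
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        lm, rm = mean(left), mean(right)
        sse = sum((y - lm) ** 2 for y in left) + sum((y - rm) ** 2 for y in right)
        if best is None or sse < best[0]:
            best = (sse, t, lm, rm)
    if best is None:  # all x values identical: predict the overall mean
        m = mean(ys)
        return lambda x: m
    _, t, lm, rm = best
    return lambda x: lm if x <= t else rm


def fit_forest(xs, ys, num_trees, seed=0):
    """Train num_trees stumps, each on its own bootstrap sample."""
    rng = random.Random(seed)
    n = len(xs)
    trees = []
    for _ in range(num_trees):
        idx = [rng.randrange(n) for _ in range(n)]  # sample with replacement
        trees.append(fit_stump([xs[i] for i in idx], [ys[i] for i in idx]))
    return trees


def predict_forest(trees, x):
    """Regression forest prediction: the average of the tree predictions."""
    return sum(tree(x) for tree in trees) / len(trees)


xs = [1.0, 2.0, 3.0, 10.0, 11.0, 12.0]
ys = [1.0, 1.0, 1.0, 5.0, 5.0, 5.0]
forest = fit_forest(xs, ys, num_trees=20, seed=0)
print(predict_forest(forest, 1.0), predict_forest(forest, 12.0))
```

<p>Averaging many trees, each fit to a different resample of the data, usually reduces variance compared with a single tree, at the cost of a less interpretable model.</p>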
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">RandomForestRegressionModel</span><span class="o">,</span> <span class="nc">RandomForestRegressor</span><span class="o">}</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="c1">// Chain indexer and forest in a Pipeline.</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span>
<span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestRegressionModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned regression forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/regression/RandomForestRegressor.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressor</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="n">RandomForestRegressor</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RandomForestRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">);</span>
<span class="c1">// Chain indexer and forest in a Pipeline</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span>
<span class="n">RandomForestRegressionModel</span> <span class="n">rfModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">RandomForestRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned regression forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaRandomForestRegressorExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor">Python API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">RandomForestRegressor</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a RandomForest model.</span>
<span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">)</span>
<span class="c"># Chain indexer and forest in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexer.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;rmse&quot;</span><span class="p">)</span>
<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span>
<span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">rfModel</span><span class="p">)</span> <span class="c"># summary only</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/random_forest_regressor_example.py" in the Spark repo.</small></div>
</div>
</div>
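<p>The <code>RegressionEvaluator</code> in the examples above is configured with <code>metricName</code> set to <code>"rmse"</code>. It therefore reports the root mean squared error between the <code>prediction</code> and <code>label</code> columns. The formula can be written out in a few lines of plain Python. This illustrates the metric only and is not Spark code; the <code>rmse</code> helper is hypothetical.</p>

```python
import math


def rmse(predictions, labels):
    """Root mean squared error: sqrt(mean((prediction - label) ** 2))."""
    if not predictions or len(predictions) != len(labels):
        raise ValueError("need two non-empty sequences of equal length")
    squared_errors = [(p - y) ** 2 for p, y in zip(predictions, labels)]
    return math.sqrt(sum(squared_errors) / len(squared_errors))


# Errors are 1, 0 and -2, so RMSE = sqrt((1 + 0 + 4) / 3).
print(rmse([2.0, 3.0, 5.0], [1.0, 3.0, 7.0]))
```

<p>RMSE is expressed in the same units as the label, and lower values are better. <code>RegressionEvaluator</code> also supports the <code>"mse"</code>, <code>"r2"</code> and <code>"mae"</code> metrics.</p>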
<h2 id="gradient-boosted-tree-regression">Gradient-boosted tree regression</h2>
<p>Gradient-boosted trees (GBTs) are a popular regression method that uses ensembles of decision trees.
More information about the <code>spark.ml</code> implementation can be found later in the <a href="#gradient-boosted-trees-gbts">section on GBTs</a>.</p>
<p><strong>Example</strong></p>
<p>Note: For this example dataset, <code>GBTRegressor</code> actually needs only 1 iteration, but that is not
true in general.</p>
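<p>A minimal sketch of gradient boosting with squared error shows why a single iteration can be enough. The first tree is fit to the labels. Each later tree is fit to the current residuals (label minus prediction) and added with weight <code>step_size</code>. If the first tree already fits the training data exactly, every residual is zero, so further trees contribute nothing. This is a simplified illustration in plain Python, not Spark's implementation. It uses hypothetical helper names and one-split trees on 1-D inputs, and the early exit on zero residuals is a feature of this sketch only. <code>GBTRegressor</code> exposes comparable <code>maxIter</code> and <code>stepSize</code> parameters.</p>

```python
from statistics import mean


def fit_stump(xs, ys):
    """Fit a one-split regression tree on 1-D inputs, minimizing squared error."""
    best = None
    for t in sorted(set(xs))[:-1]:  # thresholds that leave both sides non-empty
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        lm, rm = mean(left), mean(right)
        sse = sum((y - lm) ** 2 for y in left) + sum((y - rm) ** 2 for y in right)
        if best is None or sse < best[0]:
            best = (sse, t, lm, rm)
    if best is None:  # only one distinct input value: predict the mean target
        m = mean(ys)
        return lambda x: m
    _, t, lm, rm = best
    return lambda x: lm if x <= t else rm


def predict_gbt(ensemble, x):
    """Weighted sum of the trees' predictions."""
    return sum(weight * tree(x) for weight, tree in ensemble)


def fit_gbt(xs, ys, max_iter=10, step_size=0.1, tol=1e-12):
    """Boosting with squared error: each new tree is fit to the current residuals."""
    ensemble = [(1.0, fit_stump(xs, ys))]  # first tree fits the labels directly
    for _ in range(max_iter - 1):
        residuals = [y - predict_gbt(ensemble, x) for x, y in zip(xs, ys)]
        if max(abs(r) for r in residuals) <= tol:
            break  # training data already fit exactly (early exit: sketch only)
        ensemble.append((step_size, fit_stump(xs, residuals)))
    return ensemble


xs = [1, 2, 3, 4]
print(len(fit_gbt(xs, [0.0, 0.0, 10.0, 10.0])))  # one tree fits these labels exactly
print(len(fit_gbt(xs, [0.0, 4.0, 6.0, 10.0])))   # a single split cannot, so boosting continues
```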
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.</span><span class="o">{</span><span class="nc">GBTRegressionModel</span><span class="o">,</span> <span class="nc">GBTRegressor</span><span class="o">}</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="n">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a GBT model.</span>
<span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="c1">// Chain indexer and GBT in a Pipeline.</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span>
<span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTRegressionModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned regression GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/regression/GBTRegressor.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressor</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">read</span><span class="o">().</span><span class="na">format</span><span class="o">(</span><span class="s">&quot;libsvm&quot;</span><span class="o">).</span><span class="na">load</span><span class="o">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing).</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a GBT model.</span>
<span class="n">GBTRegressor</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">GBTRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span>
<span class="c1">// Chain indexer and GBT in a Pipeline.</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">().</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error.</span>
<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span>
<span class="n">GBTRegressionModel</span> <span class="n">gbtModel</span> <span class="o">=</span> <span class="o">(</span><span class="n">GBTRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned regression GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaGradientBoostedTreeRegressorExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor">Python API docs</a> for more details.</p>
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">GBTRegressor</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">read</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="s">&quot;libsvm&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a GBT model.</span>
<span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c"># Chain indexer and GBT in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexer.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;rmse&quot;</span><span class="p">)</span>
<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="n">rmse</span><span class="p">)</span>
<span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">print</span><span class="p">(</span><span class="n">gbtModel</span><span class="p">)</span> <span class="c"># summary only</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/gradient_boosted_tree_regressor_example.py" in the Spark repo.</small></div>
</div>
</div>
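<p>The <code>maxCategories</code> setting used in the examples above decides which features <code>VectorIndexer</code> treats as categorical. The following is a simplified, hypothetical sketch of that rule in plain Python. The helper names are ours, not Spark's. A feature with at most <code>maxCategories</code> distinct values is treated as categorical and its values are mapped to category indices; every other feature is left continuous. Assigning indices in sorted value order is an assumption of this sketch, and Spark's exact index assignment may differ.</p>

```python
from typing import Dict, List, Sequence

# Hypothetical helpers: a simplified illustration, not Spark's VectorIndexer code.
CategoryMaps = Dict[int, Dict[float, int]]


def find_categorical_features(rows: Sequence[Sequence[float]],
                              max_categories: int) -> CategoryMaps:
    """Map each feature with at most `max_categories` distinct values to a
    {raw value: category index} dict; other features are treated as continuous."""
    category_maps: CategoryMaps = {}
    for j in range(len(rows[0])):
        distinct = sorted({row[j] for row in rows})
        if len(distinct) <= max_categories:
            # Assumption: category indices follow sorted value order.
            category_maps[j] = {value: i for i, value in enumerate(distinct)}
    return category_maps


def index_row(row: Sequence[float], category_maps: CategoryMaps) -> List[float]:
    """Replace categorical values by their category index; keep continuous values as-is."""
    return [float(category_maps[j][v]) if j in category_maps else v
            for j, v in enumerate(row)]


rows = [
    [0.0, 1.5, 10.0],
    [1.0, 2.5, 10.0],
    [0.0, 3.5, 20.0],
    [1.0, 4.5, 30.0],
    [0.0, 5.5, 20.0],
]
maps = find_categorical_features(rows, max_categories=4)
# Feature 1 has 5 distinct values (more than 4), so it is treated as continuous.
print(sorted(maps))                        # [0, 2]
print(index_row([1.0, 2.5, 30.0], maps))   # [1.0, 2.5, 2.0]
```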
<h2 id="survival-regression">Survival regression</h2>
<p>In <code>spark.ml</code>, we implement the <a href="https://en.wikipedia.org/wiki/Accelerated_failure_time_model">Accelerated failure time (AFT)</a>
model, which is a parametric survival regression model for censored data.
It describes a model for the log of survival time, so it is often called a
log-linear model for survival analysis. Unlike the
<a href="https://en.wikipedia.org/wiki/Proportional_hazards_model">Proportional hazards</a> model,
which is designed for the same purpose, the AFT model is easier to parallelize
because each instance contributes to the objective function independently.</p>
<p>Given the values of the covariates $x^{'}$, for random lifetimes $t_{i}$ of
subjects $i = 1, \ldots, n$, with possible right-censoring,
the likelihood function under the AFT model is given as:
<code>\[
L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}(\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}}
\]</code>
where $\delta_{i}$ indicates whether the event has occurred, i.e. whether the observation is uncensored ($\delta_{i}=1$) or censored ($\delta_{i}=0$).
Using $\epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}$, the log-likelihood function
assumes the form:
<code>\[
\iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+\delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}]
\]</code>
where $S_{0}(\epsilon_{i})$ is the baseline survivor function,
and $f_{0}(\epsilon_{i})$ is the corresponding density function.</p>
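<p>As a minimal illustration of the formula above, the sketch below evaluates $\iota(\beta,\sigma)$ as a sum of independent per-instance terms. This is also why the objective is easy to parallelize: each partition can compute a partial sum, and the partial sums are then added. It is plain Python with hypothetical helper names, not Spark's implementation. It has no intercept, matching the formulas, and the usage plugs in the extreme value baseline $\log f_{0}(\epsilon)=\epsilon-e^{\epsilon}$, $\log S_{0}(\epsilon)=-e^{\epsilon}$ used for the Weibull model in the next paragraph, with arbitrary parameter values.</p>

```python
import math
from typing import Callable, List, Sequence, Tuple

# One instance: (t_i, delta_i, x_i), with delta_i = 1.0 if the event occurred (uncensored).
Instance = Tuple[float, float, Sequence[float]]
Baseline = Callable[[float], float]


def instance_log_likelihood(inst: Instance, beta: Sequence[float], sigma: float,
                            log_f0: Baseline, log_S0: Baseline) -> float:
    """Contribution of a single instance to iota(beta, sigma)."""
    t, delta, x = inst
    eps = (math.log(t) - sum(b * xj for b, xj in zip(beta, x))) / sigma
    return (-delta * math.log(sigma) + delta * log_f0(eps)
            + (1.0 - delta) * log_S0(eps))


def log_likelihood(partitions: List[List[Instance]], beta: Sequence[float],
                   sigma: float, log_f0: Baseline, log_S0: Baseline) -> float:
    """Compute one partial sum per partition, then add the partial sums together."""
    partial_sums = [sum(instance_log_likelihood(inst, beta, sigma, log_f0, log_S0)
                        for inst in part)
                    for part in partitions]
    return sum(partial_sums)


# Extreme value baseline (used for the Weibull model): log f0 and log S0.
def log_f0_ev(eps: float) -> float:
    return eps - math.exp(eps)


def log_S0_ev(eps: float) -> float:
    return -math.exp(eps)


data = [
    (1.218, 1.0, [1.560, -0.605]),
    (2.949, 0.0, [0.346, 2.158]),
    (3.627, 0.0, [1.380, 0.231]),
    (0.273, 1.0, [0.520, 1.151]),
    (4.199, 0.0, [0.795, -0.226]),
]
beta, sigma = [0.1, -0.2], 1.5  # arbitrary illustrative parameter values
one_partition = log_likelihood([data], beta, sigma, log_f0_ev, log_S0_ev)
two_partitions = log_likelihood([data[:2], data[2:]], beta, sigma, log_f0_ev, log_S0_ev)
print(math.isclose(one_partition, two_partitions))  # True
```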
<p>The most commonly used AFT model is based on the Weibull distribution of the survival time.
A Weibull distribution for the lifetime corresponds to an extreme value distribution for
the log of the lifetime, and the $S_{0}(\epsilon)$ function is:
<code>\[
S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}})
\]</code>
the $f_{0}(\epsilon_{i})$ function is:
<code>\[
f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}})
\]</code>
The log-likelihood function for the AFT model with a Weibull distribution of lifetime is:
<code>\[
\iota(\beta,\sigma)= -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}]
\]</code>
Since minimizing the negative log-likelihood is equivalent to maximizing the likelihood,
the loss function we optimize is $-\iota(\beta,\sigma)$.
The gradients with respect to $\beta$ and $\log\sigma$ are, respectively:
<code>\[
\frac{\partial (-\iota)}{\partial \beta}=\sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma}
\]</code>
<code>\[
\frac{\partial (-\iota)}{\partial (\log\sigma)}=\sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}]
\]</code></p>
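<p>The Weibull loss and the two gradient formulas can be checked numerically. The sketch below is plain Python with hypothetical function names, no intercept, and arbitrary parameter values. It computes $-\iota$ and its gradients with respect to $\beta$ and $\log\sigma$ exactly as written above, then compares them against central finite differences. Spark's own implementation also fits an intercept and is not reproduced here.</p>

```python
import math
from typing import List, Sequence, Tuple

Instance = Tuple[float, float, Sequence[float]]  # (t_i, delta_i, x_i)


def weibull_aft_loss_and_grad(data: Sequence[Instance], beta: Sequence[float],
                              log_sigma: float) -> Tuple[float, List[float], float]:
    """Return (-iota, d(-iota)/d beta, d(-iota)/d log sigma) for the Weibull AFT model."""
    sigma = math.exp(log_sigma)
    loss, grad_log_sigma = 0.0, 0.0
    grad_beta = [0.0] * len(beta)
    for t, delta, x in data:
        eps = (math.log(t) - sum(b * xj for b, xj in zip(beta, x))) / sigma
        e_eps = math.exp(eps)
        # Per-instance -iota_i = delta * log(sigma) - delta * eps + exp(eps)
        loss += delta * log_sigma - delta * eps + e_eps
        for j, xj in enumerate(x):
            grad_beta[j] += (delta - e_eps) * xj / sigma
        grad_log_sigma += delta + (delta - e_eps) * eps
    return loss, grad_beta, grad_log_sigma


def numeric_grad(data, beta, log_sigma, h=1e-6):
    """Central finite differences of the loss, used to check the analytic gradients."""
    def loss(b, s):
        return weibull_aft_loss_and_grad(data, b, s)[0]

    grad_beta = []
    for j in range(len(beta)):
        up, down = list(beta), list(beta)
        up[j] += h
        down[j] -= h
        grad_beta.append((loss(up, log_sigma) - loss(down, log_sigma)) / (2 * h))
    grad_log_sigma = (loss(beta, log_sigma + h) - loss(beta, log_sigma - h)) / (2 * h)
    return grad_beta, grad_log_sigma


data = [
    (1.218, 1.0, [1.560, -0.605]),
    (2.949, 0.0, [0.346, 2.158]),
    (3.627, 0.0, [1.380, 0.231]),
    (0.273, 1.0, [0.520, 1.151]),
    (4.199, 0.0, [0.795, -0.226]),
]
beta, log_sigma = [0.1, -0.2], math.log(1.5)  # arbitrary illustrative values
_, g_beta, g_log_sigma = weibull_aft_loss_and_grad(data, beta, log_sigma)
n_beta, n_log_sigma = numeric_grad(data, beta, log_sigma)
print(all(math.isclose(a, b, rel_tol=1e-5, abs_tol=1e-7) for a, b in zip(g_beta, n_beta)))  # True
print(math.isclose(g_log_sigma, n_log_sigma, rel_tol=1e-5, abs_tol=1e-7))  # True
```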
<p>The AFT model can be formulated as a convex optimization problem,
i.e. the task of finding a minimizer of a convex function $-\iota(\beta,\sigma)$
that depends on the coefficient vector $\beta$ and the log of the scale parameter $\log\sigma$.
The optimization algorithm underlying the implementation is L-BFGS.
The implementation matches the result from R&#8217;s survival function
<a href="https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html">survreg</a>.</p>
<p><strong>Example</strong></p>
<div class="codetabs">
<div data-lang="scala">
<div class="highlight"><pre><span class="k">import</span> <span class="nn">org.apache.spark.ml.linalg.Vectors</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span>
<span class="k">val</span> <span class="n">training</span> <span class="k">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="o">(</span><span class="nc">Seq</span><span class="o">(</span>
<span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span>
<span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="nc">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span>
<span class="o">)).</span><span class="n">toDF</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;censor&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">quantileProbabilities</span> <span class="k">=</span> <span class="nc">Array</span><span class="o">(</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">)</span>
<span class="k">val</span> <span class="n">aft</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">AFTSurvivalRegression</span><span class="o">()</span>
<span class="o">.</span><span class="n">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span>
<span class="o">.</span><span class="n">setQuantilesCol</span><span class="o">(</span><span class="s">&quot;quantiles&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">aft</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">training</span><span class="o">)</span>
<span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span>
<span class="n">println</span><span class="o">(</span><span class="n">s</span><span class="s">&quot;Coefficients: ${model.coefficients} Intercept: &quot;</span> <span class="o">+</span>
<span class="n">s</span><span class="s">&quot;${model.intercept} Scale: ${model.scale}&quot;</span><span class="o">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="kc">false</span><span class="o">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala" in the Spark repo.</small></div>
</div>
<div data-lang="java">
<div class="highlight"><pre><span class="kn">import</span> <span class="nn">java.util.Arrays</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">java.util.List</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.AFTSurvivalRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.linalg.*</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Dataset</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.Row</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.RowFactory</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SparkSession</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.types.*</span><span class="o">;</span>
<span class="n">List</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">Arrays</span><span class="o">.</span><span class="na">asList</span><span class="o">(</span>
<span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">1.218</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.560</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="o">)),</span>
<span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">2.949</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.346</span><span class="o">,</span> <span class="mf">2.158</span><span class="o">)),</span>
<span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">3.627</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">1.380</span><span class="o">,</span> <span class="mf">0.231</span><span class="o">)),</span>
<span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">0.273</span><span class="o">,</span> <span class="mf">1.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.520</span><span class="o">,</span> <span class="mf">1.151</span><span class="o">)),</span>
<span class="n">RowFactory</span><span class="o">.</span><span class="na">create</span><span class="o">(</span><span class="mf">4.199</span><span class="o">,</span> <span class="mf">0.0</span><span class="o">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="na">dense</span><span class="o">(</span><span class="mf">0.795</span><span class="o">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="o">))</span>
<span class="o">);</span>
<span class="n">StructType</span> <span class="n">schema</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StructType</span><span class="o">(</span><span class="k">new</span> <span class="n">StructField</span><span class="o">[]{</span>
<span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="n">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span>
<span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">&quot;censor&quot;</span><span class="o">,</span> <span class="n">DataTypes</span><span class="o">.</span><span class="na">DoubleType</span><span class="o">,</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">()),</span>
<span class="k">new</span> <span class="nf">StructField</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">,</span> <span class="k">new</span> <span class="nf">VectorUDT</span><span class="o">(),</span> <span class="kc">false</span><span class="o">,</span> <span class="n">Metadata</span><span class="o">.</span><span class="na">empty</span><span class="o">())</span>
<span class="o">});</span>
<span class="n">Dataset</span><span class="o">&lt;</span><span class="n">Row</span><span class="o">&gt;</span> <span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">data</span><span class="o">,</span> <span class="n">schema</span><span class="o">);</span>
<span class="kt">double</span><span class="o">[]</span> <span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="k">new</span> <span class="kt">double</span><span class="o">[]{</span><span class="mf">0.3</span><span class="o">,</span> <span class="mf">0.6</span><span class="o">};</span>
<span class="n">AFTSurvivalRegression</span> <span class="n">aft</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">AFTSurvivalRegression</span><span class="o">()</span>
<span class="o">.</span><span class="na">setQuantileProbabilities</span><span class="o">(</span><span class="n">quantileProbabilities</span><span class="o">)</span>
<span class="o">.</span><span class="na">setQuantilesCol</span><span class="o">(</span><span class="s">&quot;quantiles&quot;</span><span class="o">);</span>
<span class="n">AFTSurvivalRegressionModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">aft</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">training</span><span class="o">);</span>
<span class="c1">// Print the coefficients, intercept and scale parameter for AFT survival regression</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Coefficients: &quot;</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">coefficients</span><span class="o">()</span> <span class="o">+</span> <span class="s">&quot; Intercept: &quot;</span>
<span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">intercept</span><span class="o">()</span> <span class="o">+</span> <span class="s">&quot; Scale: &quot;</span> <span class="o">+</span> <span class="n">model</span><span class="o">.</span><span class="na">scale</span><span class="o">());</span>
<span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">training</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="kc">false</span><span class="o">);</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/java/org/apache/spark/examples/ml/JavaAFTSurvivalRegressionExample.java" in the Spark repo.</small></div>
</div>
<div data-lang="python">
<div class="highlight"><pre><span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">AFTSurvivalRegression</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.linalg</span> <span class="kn">import</span> <span class="n">Vectors</span>
<span class="n">training</span> <span class="o">=</span> <span class="n">spark</span><span class="o">.</span><span class="n">createDataFrame</span><span class="p">([</span>
<span class="p">(</span><span class="mf">1.218</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.560</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.605</span><span class="p">)),</span>
<span class="p">(</span><span class="mf">2.949</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.346</span><span class="p">,</span> <span class="mf">2.158</span><span class="p">)),</span>
<span class="p">(</span><span class="mf">3.627</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">1.380</span><span class="p">,</span> <span class="mf">0.231</span><span class="p">)),</span>
<span class="p">(</span><span class="mf">0.273</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.520</span><span class="p">,</span> <span class="mf">1.151</span><span class="p">)),</span>
<span class="p">(</span><span class="mf">4.199</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">Vectors</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="mf">0.795</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.226</span><span class="p">))],</span> <span class="p">[</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="s">&quot;censor&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">])</span>
<span class="n">quantileProbabilities</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">]</span>
<span class="n">aft</span> <span class="o">=</span> <span class="n">AFTSurvivalRegression</span><span class="p">(</span><span class="n">quantileProbabilities</span><span class="o">=</span><span class="n">quantileProbabilities</span><span class="p">,</span>
<span class="n">quantilesCol</span><span class="o">=</span><span class="s">&quot;quantiles&quot;</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">aft</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">training</span><span class="p">)</span>
<span class="c"># Print the coefficients, intercept and scale parameter for AFT survival regression</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Coefficients: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">coefficients</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Intercept: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">intercept</span><span class="p">))</span>
<span class="k">print</span><span class="p">(</span><span class="s">&quot;Scale: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">scale</span><span class="p">))</span>
<span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">training</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="n">truncate</span><span class="o">=</span><span class="bp">False</span><span class="p">)</span>
</pre></div>
<div><small>Find full example code at "examples/src/main/python/ml/aft_survival_regression.py" in the Spark repo.</small></div>
</div>
</div>
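<p>The examples above request lifetime quantiles via <code>quantileProbabilities</code> and <code>quantilesCol</code>. Under the Weibull model described earlier, $\log t = x^{'}\beta + \sigma\epsilon$ with $S_{0}(\epsilon)=\exp(-e^{\epsilon})$, so solving $S_{0}(\epsilon_{p})=1-p$ gives the $p$-quantile $t_{p}=\exp(x^{'}\beta)\,(-\log(1-p))^{\sigma}$. The sketch below computes this in plain Python. The function name is hypothetical, and treating the model's intercept as part of $x^{'}\beta$ and its scale as $\sigma$ is our assumption about how this maps onto the <code>quantiles</code> column, not a statement of Spark's code.</p>

```python
import math
from typing import List, Sequence


def weibull_aft_quantiles(coefficients: Sequence[float], intercept: float, scale: float,
                          features: Sequence[float],
                          probabilities: Sequence[float]) -> List[float]:
    """Lifetime quantiles t_p = exp(x'beta + intercept) * (-log(1 - p)) ** scale."""
    eta = sum(c * f for c, f in zip(coefficients, features)) + intercept
    lam = math.exp(eta)
    return [lam * (-math.log(1.0 - p)) ** scale for p in probabilities]


# Arbitrary illustrative parameters, not the fitted values from the example above.
coefficients, intercept, scale = [0.5, -0.25], 0.1, 0.7
x = [1.0, 2.0]
q30, q60 = weibull_aft_quantiles(coefficients, intercept, scale, x, [0.3, 0.6])
print(q30 < q60)  # True
```

Because $(-\log(1-p))^{\sigma}$ increases with $p$ for any $\sigma$ greater than zero, a larger quantile probability always yields a longer lifetime.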
<h1 id="decision-trees">Decision trees</h1>
<p><a href="http://en.wikipedia.org/wiki/Decision_tree_learning">Decision trees</a>
and their ensembles are popular methods for the machine learning tasks of
classification and regression. Decision trees are widely used since they are easy to interpret,
handle categorical features, extend to the multiclass classification setting, do not require
feature scaling, and are able to capture non-linearities and feature interactions. Tree ensemble
algorithms such as random forests and boosting are among the top performers for classification and
regression tasks.</p>
<p>The <code>spark.ml</code> implementation supports decision trees for binary and multiclass classification and for regression,
using both continuous and categorical features. The implementation partitions data by rows,
allowing distributed training with millions or even billions of instances.</p>
<p>Users can find more information about the decision tree algorithm in the <a href="mllib-decision-tree.html">MLlib Decision Tree guide</a>.
The main differences between this API and the <a href="mllib-decision-tree.html">original MLlib Decision Tree API</a> are:</p>
<ul>
<li>support for ML Pipelines</li>
<li>separation of Decision Trees for classification vs. regression</li>
<li>use of DataFrame metadata to distinguish continuous and categorical features</li>
</ul>
<p>The Pipelines API for Decision Trees offers more functionality than the original API.<br />
In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities);
for regression, users can get the biased sample variance of the prediction.</p>
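<p>To make "biased sample variance" concrete, here is a minimal plain-Python sketch (not the <code>spark.ml</code> source). It assumes that a regression tree predicts the mean of the training labels at the leaf an instance reaches, and that the reported variance is taken over those same labels, dividing by <em>n</em> rather than <em>n</em> - 1. The function names and leaf labels are hypothetical.</p>

```python
# Minimal sketch, not spark.ml code: relating a regression leaf's prediction
# to the "biased sample variance" of its training labels. Assumes the leaf
# predicts the mean label; the names and data are hypothetical.


def leaf_prediction_and_variance(labels):
    """Return (mean, biased variance) of the training labels at one leaf.

    The biased variance divides the sum of squared deviations by n.
    """
    n = len(labels)
    mean = sum(labels) / n
    biased_var = sum((y - mean) ** 2 for y in labels) / n
    return mean, biased_var


def unbiased_variance(labels):
    """Bessel-corrected sample variance (divides by n - 1), for comparison."""
    n = len(labels)
    mean = sum(labels) / n
    return sum((y - mean) ** 2 for y in labels) / (n - 1)


leaf_labels = [1.0, 2.0, 3.0, 6.0]
prediction, variance = leaf_prediction_and_variance(leaf_labels)
print(prediction, variance)            # 3.0 3.5
print(unbiased_variance(leaf_labels))  # 14 / 3, roughly 4.667
```

<p>For the same leaf, the biased estimate is never larger than the Bessel-corrected one; the gap is largest for leaves holding few training instances.</p>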
<p>Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the <a href="#tree-ensembles">Tree ensembles section</a>.</p>
<h2 id="inputs-and-outputs">Inputs and Outputs</h2>
<p>We list the input and output (prediction) column types here.
All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p>
<h3 id="input-columns">Input Columns</h3>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>labelCol</td>
<td>Double</td>
<td>"label"</td>
<td>Label to predict</td>
</tr>
<tr>
<td>featuresCol</td>
<td>Vector</td>
<td>"features"</td>
<td>Feature vector</td>
</tr>
</tbody>
</table>
<h3 id="output-columns">Output Columns</h3>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
<th align="left">Notes</th>
</tr>
</thead>
<tbody>
<tr>
<td>predictionCol</td>
<td>Double</td>
<td>"prediction"</td>
<td>Predicted label</td>
<td></td>
</tr>
<tr>
<td>rawPredictionCol</td>
<td>Vector</td>
<td>"rawPrediction"</td>
<td>Vector of length equal to the number of classes, holding the counts of training instance labels at the tree node that makes the prediction</td>
<td>Classification only</td>
</tr>
<tr>
<td>probabilityCol</td>
<td>Vector</td>
<td>"probability"</td>
<td>Vector of length equal to the number of classes: rawPrediction normalized to a multinomial distribution</td>
<td>Classification only</td>
</tr>
<tr>
<td>varianceCol</td>
<td>Double</td>
<td></td>
<td>Biased sample variance of the prediction, i.e. of the training labels at the leaf node that makes the prediction</td>
<td>Regression only</td>
</tr>
</tbody>
</table>
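<p>The classification output columns are derived from one another. The plain-Python sketch below (hypothetical names, not the <code>spark.ml</code> implementation) takes the label counts at the leaf that makes a prediction as <code>rawPrediction</code>, normalizes them into <code>probability</code>, and picks the most probable class index as <code>prediction</code>. Resolving ties toward the lowest class index is an assumption of this sketch.</p>

```python
# Minimal sketch (plain Python, not the spark.ml implementation) of how the
# three classification output columns relate for a single decision tree.
# The leaf class counts used below are hypothetical.


def tree_outputs(leaf_class_counts):
    """Map the label counts at the predicting leaf to the output columns.

    rawPrediction: the counts themselves (one entry per class).
    probability:   the counts normalized to sum to 1 (a multinomial distribution).
    prediction:    the index of the largest probability, as a Double label.
    """
    raw_prediction = list(leaf_class_counts)
    total = sum(raw_prediction)
    probability = [c / total for c in raw_prediction]
    # max() returns the first maximal index, so ties go to the lowest class.
    prediction = float(max(range(len(probability)), key=probability.__getitem__))
    return raw_prediction, probability, prediction


# A leaf reached by 2 training instances of class 0, 6 of class 1, 2 of class 2.
raw, prob, pred = tree_outputs([2.0, 6.0, 2.0])
print(raw)   # [2.0, 6.0, 2.0]
print(prob)  # [0.2, 0.6, 0.2]
print(pred)  # 1.0
```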
<h1 id="tree-ensembles">Tree Ensembles</h1>
<p>The DataFrame API supports two major tree ensemble algorithms: <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forests</a> and <a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a>.
Both use <a href="ml-classification-regression.html#decision-trees"><code>spark.ml</code> decision trees</a> as their base models.</p>
<p>Users can find more information about ensemble algorithms in the <a href="mllib-ensembles.html">MLlib Ensemble guide</a>.<br />
In this section, we demonstrate the DataFrame API for ensembles.</p>
<p>The main differences between this API and the <a href="mllib-ensembles.html">original MLlib ensembles API</a> are:</p>
<ul>
<li>support for DataFrames and ML Pipelines</li>
<li>separation of classification vs. regression</li>
<li>use of DataFrame metadata to distinguish continuous and categorical features</li>
<li>more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification.</li>
</ul>
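<p>The feature importance estimates mentioned above can be illustrated with a minimal sketch of a common gain-based scheme: for each tree, sum the impurity gain of every split weighted by the number of instances reaching that node, normalize per tree, add across trees, and normalize again. The tree representation (a list of <code>(feature_index, gain, num_instances)</code> tuples) and all names are hypothetical, and whether <code>spark.ml</code> follows exactly these steps should be checked against the API documentation for <code>featureImportances</code>.</p>

```python
# Minimal sketch of gain-based feature importance for a tree ensemble.
# Hypothetical representation: each tree is a list of its internal (split)
# nodes as (feature_index, gain, num_instances) tuples. Not spark.ml code.


def tree_importances(internal_nodes, num_features):
    """Sum gain * instance count per split feature, then normalize to sum to 1."""
    importances = [0.0] * num_features
    for feature, gain, count in internal_nodes:
        importances[feature] += gain * count
    total = sum(importances)
    # A tree with no splits (a single leaf) contributes nothing.
    return [v / total for v in importances] if total else importances


def ensemble_importances(trees, num_features):
    """Add the per-tree normalized importances, then normalize the total."""
    combined = [0.0] * num_features
    for nodes in trees:
        for j, v in enumerate(tree_importances(nodes, num_features)):
            combined[j] += v
    total = sum(combined)
    return [v / total for v in combined] if total else combined


# Two hypothetical trees over 3 features.
trees = [
    [(0, 0.5, 100), (1, 0.1, 40)],  # tree 1 splits on features 0 and 1
    [(2, 0.3, 100), (0, 0.2, 50)],  # tree 2 splits on features 2 and 0
]
print(ensemble_importances(trees, 3))  # roughly [0.588, 0.037, 0.375]
```

<p>Normalizing each tree first keeps a single tree with unusually large gains from dominating the ensemble-wide ranking.</p>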
<h2 id="random-forests">Random Forests</h2>
<p><a href="http://en.wikipedia.org/wiki/Random_forest">Random forests</a>
are ensembles of <a href="ml-classification-regression.html#decision-trees">decision trees</a>.
Random forests combine many decision trees in order to reduce the risk of overfitting.
The <code>spark.ml</code> implementation supports random forests for binary and multiclass classification and for regression,
using both continuous and categorical features.</p>
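<p>Part of what keeps the trees in a forest from all overfitting in the same way is that each split considers only a random subset of the features. <code>spark.ml</code> controls the subset size with the <code>featureSubsetStrategy</code> parameter. The sketch below reflects our understanding of how each strategy name maps to a number of candidate features per node; treat the exact rounding as an assumption, and note that <code>features_per_node</code> and <code>sample_split_candidates</code> are hypothetical helpers, not Spark functions.</p>

```python
import math
import random

# Sketch of how featureSubsetStrategy names could map to the number of
# features considered at each split. The rounding rules are our assumption;
# features_per_node and sample_split_candidates are hypothetical helpers.


def features_per_node(strategy, num_features, num_trees, is_classification):
    """Number of candidate features examined at each tree node."""
    if strategy == "auto":
        # A single tree uses every feature; forests subsample.
        if num_trees == 1:
            strategy = "all"
        else:
            strategy = "sqrt" if is_classification else "onethird"
    if strategy == "all":
        return num_features
    if strategy == "sqrt":
        return math.ceil(math.sqrt(num_features))
    if strategy == "log2":
        return max(1, math.ceil(math.log2(num_features)))
    if strategy == "onethird":
        return math.ceil(num_features / 3.0)
    raise ValueError("unknown featureSubsetStrategy: " + strategy)


def sample_split_candidates(num_features, k, rng):
    """Draw k distinct feature indices to evaluate at one node."""
    return sorted(rng.sample(range(num_features), k))


k = features_per_node("auto", 100, num_trees=20, is_classification=True)
print(k)  # 10
print(sample_split_candidates(100, k, random.Random(42)))
```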
<p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html"><code>spark.mllib</code> documentation on random forests</a>.</p>
<h3 id="inputs-and-outputs-1">Inputs and Outputs</h3>
<p>We list the input and output (prediction) column types here.
All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p>
<h4 id="input-columns-1">Input Columns</h4>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>labelCol</td>
<td>Double</td>
<td>"label"</td>
<td>Label to predict</td>
</tr>
<tr>
<td>featuresCol</td>
<td>Vector</td>
<td>"features"</td>
<td>Feature vector</td>
</tr>
</tbody>
</table>
<h4 id="output-columns-predictions">Output Columns (Predictions)</h4>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
<th align="left">Notes</th>
</tr>
</thead>
<tbody>
<tr>
<td>predictionCol</td>
<td>Double</td>
<td>"prediction"</td>
<td>Predicted label</td>
<td></td>
</tr>
<tr>
<td>rawPredictionCol</td>
<td>Vector</td>
<td>"rawPrediction"</td>
<td>Vector of length equal to the number of classes; each entry sums, over all trees, the fraction of training instances with that label at the leaf node that makes the tree's prediction</td>
<td>Classification only</td>
</tr>
<tr>
<td>probabilityCol</td>
<td>Vector</td>
<td>"probability"</td>
<td>Vector of length equal to the number of classes: rawPrediction normalized to a multinomial distribution</td>
<td>Classification only</td>
</tr>
</tbody>
</table>
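<p>A random forest classifier has to combine per-tree outputs into these columns. The minimal sketch below assumes (our reading of the <code>spark.ml</code> 2.0 source, not a documented guarantee) that each tree contributes its leaf class counts normalized to sum to 1, that <code>rawPrediction</code> is the sum of those contributions, and that <code>probability</code> normalizes that sum. <code>forest_outputs</code> and the leaf counts are hypothetical.</p>

```python
# Sketch (assumption: per-tree leaf counts are normalized, then summed) of
# how a random forest classifier could combine trees into rawPrediction,
# probability, and prediction. forest_outputs is a hypothetical helper.


def forest_outputs(per_tree_leaf_counts):
    """Combine each tree's leaf class counts into the three output columns."""
    num_classes = len(per_tree_leaf_counts[0])
    raw_prediction = [0.0] * num_classes
    for counts in per_tree_leaf_counts:
        total = sum(counts)
        if total:
            # Each tree adds its leaf's class distribution (summing to 1).
            for i, c in enumerate(counts):
                raw_prediction[i] += c / total
    s = sum(raw_prediction)
    probability = [r / s for r in raw_prediction]
    prediction = float(max(range(num_classes), key=probability.__getitem__))
    return raw_prediction, probability, prediction


# Leaf class counts reached in three hypothetical trees (2 classes).
raw, prob, pred = forest_outputs([[8.0, 2.0], [1.0, 3.0], [5.0, 5.0]])
print(raw)   # roughly [1.55, 1.45]
print(prob)  # roughly [0.517, 0.483]
print(pred)  # 0.0
```

<p>Here tree 1 favors class 0, tree 2 favors class 1, and tree 3 is undecided; because tree 1 is more confident than tree 2, the summed distribution tips toward class 0.</p>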
<h2 id="gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</h2>
<p><a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a>
are ensembles of <a href="ml-classification-regression.html#decision-trees">decision trees</a>.
GBTs iteratively train decision trees in order to minimize a loss function.
The <code>spark.ml</code> implementation supports GBTs for binary classification and for regression,
using both continuous and categorical features.</p>
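<p>The idea of iteratively training trees to minimize a loss can be sketched for regression with squared error: each round fits a small tree to the negative gradient of the loss (for $\frac{1}{2}(y - F)^2$ that is simply the residual $y - F$) and adds it, scaled by a step size, to the ensemble. This is a generic textbook version using one-split trees on a single feature. It starts from the mean label, and the names <code>num_iterations</code> and <code>step_size</code> only loosely mirror the <code>maxIter</code> and <code>stepSize</code> parameters of <code>spark.ml</code> GBTs; the <code>spark.ml</code> initialization and loss scaling may differ.</p>

```python
# Minimal sketch of gradient boosting for regression with squared error,
# using one-split regression trees (stumps) on a single numeric feature.
# A generic textbook version, not the spark.ml implementation; all names
# and data are hypothetical.


def fit_stump(xs, targets):
    """Fit a one-split regression tree that minimizes squared error."""
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    xs_s = [xs[i] for i in order]
    ts_s = [targets[i] for i in order]
    mean_all = sum(ts_s) / len(ts_s)
    # Candidates are (sse, threshold, left_value, right_value); None = no split.
    candidates = [(sum((t - mean_all) ** 2 for t in ts_s), None, mean_all, mean_all)]
    for i in range(1, len(xs_s)):
        if xs_s[i] == xs_s[i - 1]:
            continue  # no threshold separates equal feature values
        left, right = ts_s[:i], ts_s[i:]
        lm = sum(left) / len(left)
        rm = sum(right) / len(right)
        sse = sum((t - lm) ** 2 for t in left) + sum((t - rm) ** 2 for t in right)
        candidates.append((sse, (xs_s[i - 1] + xs_s[i]) / 2.0, lm, rm))
    _, threshold, left_value, right_value = min(candidates, key=lambda c: c[0])
    return threshold, left_value, right_value


def predict_stump(stump, x):
    """Route x to the left or right value of a fitted stump."""
    threshold, left_value, right_value = stump
    if threshold is None:
        return left_value
    return right_value if x > threshold else left_value


def gradient_boost(xs, ys, num_iterations=20, step_size=0.1):
    """Start from the mean label, then repeatedly fit stumps to residuals."""
    base = sum(ys) / len(ys)
    preds = [base] * len(ys)
    stumps = []
    for _ in range(num_iterations):
        # Negative gradient of (y - F)^2 / 2 with respect to F is y - F.
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        preds = [p + step_size * predict_stump(stump, x) for p, x in zip(preds, xs)]
    return base, stumps, step_size


def predict(model, x):
    """Mean label plus the step-scaled sum of every stump's output."""
    base, stumps, step_size = model
    return base + step_size * sum(predict_stump(s, x) for s in stumps)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [1.0, 1.0, 1.0, 5.0, 5.0, 5.0]
model = gradient_boost(xs, ys, num_iterations=50, step_size=0.1)
print(predict(model, 1.0), predict(model, 6.0))  # close to 1.0 and 5.0
```

<p>A small step size makes each tree correct only part of the remaining error, which is why more iterations are needed as the step size shrinks.</p>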
<p>For more information on the algorithm itself, please see the <a href="mllib-ensembles.html"><code>spark.mllib</code> documentation on GBTs</a>.</p>
<h3 id="inputs-and-outputs-2">Inputs and Outputs</h3>
<p>We list the input and output (prediction) column types here.
All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p>
<h4 id="input-columns-2">Input Columns</h4>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>labelCol</td>
<td>Double</td>
<td>"label"</td>
<td>Label to predict</td>
</tr>
<tr>
<td>featuresCol</td>
<td>Vector</td>
<td>"features"</td>
<td>Feature vector</td>
</tr>
</tbody>
</table>
<p>Note that <code>GBTClassifier</code> currently only supports binary labels.</p>
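<p>One way to see why the labels must be binary is the log loss that the <code>spark.mllib</code> ensembles guide lists for GBT classification, $2 \sum_{i=1}^{N} \log(1 + \exp(-2 y_i F(x_i)))$, which is written for labels $y_i \in \{-1, +1\}$. The sketch below (hypothetical helper names, not Spark functions) checks that 0/1 labels are binary, maps them to -1/+1, and evaluates that loss for given margins $F(x_i)$.</p>

```python
import math

# Sketch of the binary log loss used for GBT classification, with labels
# mapped from {0, 1} to {-1, +1}. Helper names are hypothetical.


def to_signed_labels(labels):
    """Check that labels are binary (0.0 or 1.0) and map them to -1.0 / +1.0."""
    signed = []
    for y in labels:
        if y not in (0.0, 1.0):
            raise ValueError("expected a binary label 0 or 1, got %r" % (y,))
        signed.append(2.0 * y - 1.0)
    return signed


def log_loss(signed_labels, margins):
    """Sum over instances of 2 * log(1 + exp(-2 * y * F(x)))."""
    return sum(
        2.0 * math.log1p(math.exp(-2.0 * y * f))
        for y, f in zip(signed_labels, margins)
    )


y = to_signed_labels([0.0, 1.0, 1.0])
print(y)  # [-1.0, 1.0, 1.0]
print(log_loss(y, [0.0, 0.0, 0.0]))  # 3 * 2 * log(2), roughly 4.159
```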
<h4 id="output-columns-predictions-1">Output Columns (Predictions)</h4>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
<th align="left">Notes</th>
</tr>
</thead>
<tbody>
<tr>
<td>predictionCol</td>
<td>Double</td>
<td>"prediction"</td>
<td>Predicted label</td>
<td></td>
</tr>
</tbody>
</table>
<p>In the future, <code>GBTClassifier</code> will also output columns for <code>rawPrediction</code> and <code>probability</code>, just as <code>RandomForestClassifier</code> does.</p>
</div>
<!-- /container -->
</div>
<script src="js/vendor/jquery-1.8.0.min.js"></script>
<script src="js/vendor/bootstrap.min.js"></script>
<script src="js/vendor/anchor.min.js"></script>
<script src="js/main.js"></script>
<!-- MathJax Section -->
<script type="text/x-mathjax-config">
MathJax.Hub.Config({
TeX: { equationNumbers: { autoNumber: "AMS" } }
});
</script>
<script>
// Note that we load MathJax this way to work with local file (file://), HTTP and HTTPS.
// We could use "//cdn.mathjax...", but that won't support "file://".
(function(d, script) {
script = d.createElement('script');
script.type = 'text/javascript';
script.async = true;
script.onload = function(){
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ],
displayMath: [ ["$$","$$"], ["\\[", "\\]"] ],
processEscapes: true,
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre']
}
});
};
script.src = ('https:' == document.location.protocol ? 'https://' : 'http://') +
'cdn.mathjax.org/mathjax/latest/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
d.getElementsByTagName('head')[0].appendChild(script);
}(document));
</script>
</body>
</html>