blob: a53f5b8f311fac7be3f2b55f48fa89b7e914e163 [file] [log] [blame]
<!DOCTYPE html>
<!--[if lt IE 7]> <html class="no-js lt-ie9 lt-ie8 lt-ie7" lang="en"> <![endif]-->
<!--[if IE 7]> <html class="no-js lt-ie9 lt-ie8" lang="en"> <![endif]-->
<!--[if IE 8]> <html class="no-js lt-ie9" lang="en"> <![endif]-->
<!--[if gt IE 8]><!--> <html class="no-js" lang="en"> <!--<![endif]-->
<head>
<meta charset="utf-8">
<meta http-equiv="X-UA-Compatible" content="IE=edge,chrome=1">
<title>Ensembles - Spark 1.5.2 Documentation</title>
<link rel="stylesheet" href="css/bootstrap.min.css">
<style>
body {
padding-top: 60px;
padding-bottom: 40px;
}
</style>
<meta name="viewport" content="width=device-width">
<link rel="stylesheet" href="css/bootstrap-responsive.min.css">
<link rel="stylesheet" href="css/main.css">
<script src="js/vendor/modernizr-2.6.1-respond-1.1.0.min.js"></script>
<link rel="stylesheet" href="css/pygments-default.css">
<!-- Google analytics script -->
<script type="text/javascript">
// Google Analytics (legacy ga.js asynchronous snippet).
// Commands are queued on the global _gaq array; presumably ga.js consumes
// this queue once it has loaded (behaviour of the external script, not
// visible here).
// NOTE(review): ga.js / Universal Analytics is a legacy API (UA properties
// stopped processing data in 2023) -- confirm this tracking is still wanted.
var _gaq = _gaq || [];
_gaq.push(['_setAccount', 'UA-32518208-2']);
_gaq.push(['_trackPageview']);
(function() {
  // Build an async <script> element for ga.js and insert it before the
  // first <script> element in the document.
  var ga = document.createElement('script');
  ga.type = 'text/javascript';
  ga.async = true;
  // Always fetch over HTTPS. The previous fallback to http://www on
  // non-HTTPS pages let a network attacker substitute the script.
  ga.src = 'https://ssl.google-analytics.com/ga.js';
  var s = document.getElementsByTagName('script')[0];
  s.parentNode.insertBefore(ga, s);
})();
</script>
</head>
<body>
<!--[if lt IE 7]>
<p class="chromeframe">You are using an outdated browser. <a href="http://browsehappy.com/">Upgrade your browser today</a> or <a href="http://www.google.com/chromeframe/?redirect=true">install Google Chrome Frame</a> to better experience this site.</p>
<![endif]-->
<!-- This code is taken from http://twitter.github.com/bootstrap/examples/hero.html -->
<div class="navbar navbar-fixed-top" id="topbar">
<div class="navbar-inner">
<div class="container">
<div class="brand"><a href="index.html">
<img src="img/spark-logo-hd.png" style="height:50px;" alt="Apache Spark"></a><span class="version">1.5.2</span>
</div>
<ul class="nav">
<!--TODO(andyk): Add class="active" attribute to li somehow.-->
<li><a href="index.html">Overview</a></li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Programming Guides<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="quick-start.html">Quick Start</a></li>
<li><a href="programming-guide.html">Spark Programming Guide</a></li>
<li class="divider"></li>
<li><a href="streaming-programming-guide.html">Spark Streaming</a></li>
<li><a href="sql-programming-guide.html">DataFrames and SQL</a></li>
<li><a href="mllib-guide.html">MLlib (Machine Learning)</a></li>
<li><a href="graphx-programming-guide.html">GraphX (Graph Processing)</a></li>
<li><a href="bagel-programming-guide.html">Bagel (Pregel on Spark)</a></li>
<li><a href="sparkr.html">SparkR (R on Spark)</a></li>
</ul>
</li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">API Docs<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="api/scala/index.html#org.apache.spark.package">Scala</a></li>
<li><a href="api/java/index.html">Java</a></li>
<li><a href="api/python/index.html">Python</a></li>
<li><a href="api/R/index.html">R</a></li>
</ul>
</li>
<li class="dropdown">
<a href="#" class="dropdown-toggle" data-toggle="dropdown">Deploying<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="cluster-overview.html">Overview</a></li>
<li><a href="submitting-applications.html">Submitting Applications</a></li>
<li class="divider"></li>
<li><a href="spark-standalone.html">Spark Standalone</a></li>
<li><a href="running-on-mesos.html">Mesos</a></li>
<li><a href="running-on-yarn.html">YARN</a></li>
<li class="divider"></li>
<li><a href="ec2-scripts.html">Amazon EC2</a></li>
</ul>
</li>
<li class="dropdown">
<a href="api.html" class="dropdown-toggle" data-toggle="dropdown">More<b class="caret"></b></a>
<ul class="dropdown-menu">
<li><a href="configuration.html">Configuration</a></li>
<li><a href="monitoring.html">Monitoring</a></li>
<li><a href="tuning.html">Tuning Guide</a></li>
<li><a href="job-scheduling.html">Job Scheduling</a></li>
<li><a href="security.html">Security</a></li>
<li><a href="hardware-provisioning.html">Hardware Provisioning</a></li>
<li><a href="hadoop-third-party-distributions.html">3<sup>rd</sup>-Party Hadoop Distros</a></li>
<li class="divider"></li>
<li><a href="building-spark.html">Building Spark</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark">Contributing to Spark</a></li>
<li><a href="https://cwiki.apache.org/confluence/display/SPARK/Supplemental+Spark+Projects">Supplemental Projects</a></li>
</ul>
</li>
</ul>
<!--<p class="navbar-text pull-right"><span class="version-text">v1.5.2</span></p>-->
</div>
</div>
</div>
<div class="container" id="content">
<h1 class="title"><a href="ml-guide.html">ML</a> - Ensembles</h1>
<p><strong>Table of Contents</strong></p>
<ul id="markdown-toc">
<li><a href="#tree-ensembles" id="markdown-toc-tree-ensembles">Tree Ensembles</a> <ul>
<li><a href="#random-forests" id="markdown-toc-random-forests">Random Forests</a> <ul>
<li><a href="#inputs-and-outputs" id="markdown-toc-inputs-and-outputs">Inputs and Outputs</a> <ul>
<li><a href="#input-columns" id="markdown-toc-input-columns">Input Columns</a></li>
<li><a href="#output-columns-predictions" id="markdown-toc-output-columns-predictions">Output Columns (Predictions)</a></li>
</ul>
</li>
<li><a href="#example-classification" id="markdown-toc-example-classification">Example: Classification</a></li>
<li><a href="#example-regression" id="markdown-toc-example-regression">Example: Regression</a></li>
</ul>
</li>
<li><a href="#gradient-boosted-trees-gbts" id="markdown-toc-gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</a> <ul>
<li><a href="#inputs-and-outputs-1" id="markdown-toc-inputs-and-outputs-1">Inputs and Outputs</a> <ul>
<li><a href="#input-columns-1" id="markdown-toc-input-columns-1">Input Columns</a></li>
<li><a href="#output-columns-predictions-1" id="markdown-toc-output-columns-predictions-1">Output Columns (Predictions)</a></li>
</ul>
</li>
<li><a href="#example-classification-1" id="markdown-toc-example-classification-1">Example: Classification</a></li>
<li><a href="#example-regression-1" id="markdown-toc-example-regression-1">Example: Regression</a></li>
</ul>
</li>
</ul>
</li>
<li><a href="#one-vs-rest-aka-one-vs-all" id="markdown-toc-one-vs-rest-aka-one-vs-all">One-vs-Rest (a.k.a. One-vs-All)</a> <ul>
<li><a href="#example" id="markdown-toc-example">Example</a></li>
</ul>
</li>
</ul>
<p>An <a href="http://en.wikipedia.org/wiki/Ensemble_learning">ensemble method</a>
is a learning algorithm which creates a model composed of a set of other base models.</p>
<h2 id="tree-ensembles">Tree Ensembles</h2>
<p>The Pipelines API supports two major tree ensemble algorithms: <a href="http://en.wikipedia.org/wiki/Random_forest">Random Forests</a> and <a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a>.
Both use <a href="ml-decision-tree.html">MLlib decision trees</a> as their base models.</p>
<p>Users can find more information about ensemble algorithms in the <a href="mllib-ensembles.html">MLlib Ensemble guide</a>. In this section, we demonstrate the Pipelines API for ensembles.</p>
<p>The main differences between this API and the <a href="mllib-ensembles.html">original MLlib ensembles API</a> are:</p>
<ul>
<li>support for ML Pipelines</li>
<li>separation of classification vs. regression</li>
<li>use of DataFrame metadata to distinguish continuous and categorical features</li>
<li>a bit more functionality for random forests: estimates of feature importance, as well as the predicted probability of each class (a.k.a. class conditional probabilities) for classification.</li>
</ul>
<h3 id="random-forests">Random Forests</h3>
<p><a href="http://en.wikipedia.org/wiki/Random_forest">Random forests</a>
are ensembles of <a href="ml-decision-tree.html">decision trees</a>.
Random forests combine many decision trees in order to reduce the risk of overfitting.
MLlib supports random forests for binary and multiclass classification and for regression,
using both continuous and categorical features.</p>
<p>This section gives examples of using random forests with the Pipelines API.
For more information on the algorithm, please see the <a href="mllib-ensembles.html">main MLlib docs on random forests</a>.</p>
<h4 id="inputs-and-outputs">Inputs and Outputs</h4>
<p>We list the input and output (prediction) column types here.
All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p>
<h5 id="input-columns">Input Columns</h5>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>labelCol</td>
<td>Double</td>
<td>"label"</td>
<td>Label to predict</td>
</tr>
<tr>
<td>featuresCol</td>
<td>Vector</td>
<td>"features"</td>
<td>Feature vector</td>
</tr>
</tbody>
</table>
<h5 id="output-columns-predictions">Output Columns (Predictions)</h5>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
<th align="left">Notes</th>
</tr>
</thead>
<tbody>
<tr>
<td>predictionCol</td>
<td>Double</td>
<td>"prediction"</td>
<td>Predicted label</td>
<td></td>
</tr>
<tr>
<td>rawPredictionCol</td>
<td>Vector</td>
<td>"rawPrediction"</td>
<td>Vector of length # classes, with the counts of training instance labels at the tree node which makes the prediction</td>
<td>Classification only</td>
</tr>
<tr>
<td>probabilityCol</td>
<td>Vector</td>
<td>"probability"</td>
<td>Vector of length # classes equal to rawPrediction normalized to a multinomial distribution</td>
<td>Classification only</td>
</tr>
</tbody>
</table>
<h4 id="example-classification">Example: Classification</h4>
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code>DataFrame</code> which the tree-based algorithms can recognize.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.RandomForestClassifier">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassifier</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassificationModel</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="nc">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">).</span><span class="n">toDF</span><span class="o">()</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setNumTrees</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span>
<span class="c1">// Chain indexers and forest in a Pipeline</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span>
<span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestClassificationModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned classification forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span></code></pre></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/classification/RandomForestClassifier.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassifier</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.RandomForestClassificationModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.rdd.RDD</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">RDD</span><span class="o">&lt;</span><span class="n">LabeledPoint</span><span class="o">&gt;</span> <span class="n">rdd</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="na">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">.</span><span class="na">sc</span><span class="o">(),</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">rdd</span><span class="o">,</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="n">RandomForestClassifier</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RandomForestClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">);</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span>
<span class="c1">// Chain indexers and forest in a Pipeline</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span>
<span class="n">RandomForestClassificationModel</span> <span class="n">rfModel</span> <span class="o">=</span>
<span class="o">(</span><span class="n">RandomForestClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned classification forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.RandomForestClassifier">Python API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">RandomForestClassifier</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.util</span> <span class="kn">import</span> <span class="n">MLUtils</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">toDF</span><span class="p">()</span>
<span class="c"># Index labels, adding metadata to the label column.</span>
<span class="c"># Fit on whole dataset to include all labels in index.</span>
<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a RandomForest model.</span>
<span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">)</span>
<span class="c"># Chain indexers and forest in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexers.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;precision&quot;</span><span class="p">)</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span> <span class="s">&quot;Test Error = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">)</span>
<span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="k">print</span> <span class="n">rfModel</span> <span class="c"># summary only</span></code></pre></div>
</div>
</div>
<h4 id="example-regression">Example: Regression</h4>
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use a feature transformer to index categorical features, adding metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.RandomForestRegressor">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressor</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressionModel</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="nc">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">).</span><span class="n">toDF</span><span class="o">()</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="k">val</span> <span class="n">rf</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RandomForestRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="c1">// Chain indexer and forest in a Pipeline</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span>
<span class="k">val</span> <span class="n">rfModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">RandomForestRegressionModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned regression forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span></code></pre></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/regression/RandomForestRegressor.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.RandomForestRegressor</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.rdd.RDD</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">RDD</span><span class="o">&lt;</span><span class="n">LabeledPoint</span><span class="o">&gt;</span> <span class="n">rdd</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="na">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">.</span><span class="na">sc</span><span class="o">(),</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">rdd</span><span class="o">,</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a RandomForest model.</span>
<span class="n">RandomForestRegressor</span> <span class="n">rf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RandomForestRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">);</span>
<span class="c1">// Chain indexer and forest in a Pipeline</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">rf</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span>
<span class="n">RandomForestRegressionModel</span> <span class="n">rfModel</span> <span class="o">=</span>
<span class="o">(</span><span class="n">RandomForestRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned regression forest model:\n&quot;</span> <span class="o">+</span> <span class="n">rfModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.RandomForestRegressor">Python API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">RandomForestRegressor</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.util</span> <span class="kn">import</span> <span class="n">MLUtils</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">toDF</span><span class="p">()</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a RandomForest model.</span>
<span class="n">rf</span> <span class="o">=</span> <span class="n">RandomForestRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">)</span>
<span class="c"># Chain indexer and forest in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">rf</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexer.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;rmse&quot;</span><span class="p">)</span>
<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span> <span class="s">&quot;Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="n">rmse</span>
<span class="n">rfModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">print</span> <span class="n">rfModel</span> <span class="c"># summary only</span></code></pre></div>
</div>
</div>
<h3 id="gradient-boosted-trees-gbts">Gradient-Boosted Trees (GBTs)</h3>
<p><a href="http://en.wikipedia.org/wiki/Gradient_boosting">Gradient-Boosted Trees (GBTs)</a>
are ensembles of <a href="ml-decision-tree.html">decision trees</a>.
GBTs iteratively train decision trees in order to minimize a loss function.
MLlib supports GBTs for binary classification and for regression,
using both continuous and categorical features.</p>
<p>This section gives examples of using GBTs with the Pipelines API.
For more information on the algorithm, please see the <a href="mllib-ensembles.html">main MLlib docs on GBTs</a>.</p>
<h4 id="inputs-and-outputs-1">Inputs and Outputs</h4>
<p>We list the input and output (prediction) column types here.
All output columns are optional; to exclude an output column, set its corresponding Param to an empty string.</p>
<h5 id="input-columns-1">Input Columns</h5>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
</tr>
</thead>
<tbody>
<tr>
<td>labelCol</td>
<td>Double</td>
<td>"label"</td>
<td>Label to predict</td>
</tr>
<tr>
<td>featuresCol</td>
<td>Vector</td>
<td>"features"</td>
<td>Feature vector</td>
</tr>
</tbody>
</table>
<p>Note that <code>GBTClassifier</code> currently supports only binary labels.</p>
<h5 id="output-columns-predictions-1">Output Columns (Predictions)</h5>
<table class="table">
<thead>
<tr>
<th align="left">Param name</th>
<th align="left">Type(s)</th>
<th align="left">Default</th>
<th align="left">Description</th>
<th align="left">Notes</th>
</tr>
</thead>
<tbody>
<tr>
<td>predictionCol</td>
<td>Double</td>
<td>"prediction"</td>
<td>Predicted label</td>
<td></td>
</tr>
</tbody>
</table>
<p>In the future, <code>GBTClassifier</code> will also output columns for <code>rawPrediction</code> and <code>probability</code>, just as <code>RandomForestClassifier</code> does.</p>
<h4 id="example-classification-1">Example: Classification</h4>
<p>The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set.
We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the <code>DataFrame</code> that the tree-based algorithms can recognize.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.GBTClassifier">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassifier</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassificationModel</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.</span><span class="o">{</span><span class="nc">StringIndexer</span><span class="o">,</span> <span class="nc">IndexToString</span><span class="o">,</span> <span class="nc">VectorIndexer</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="nc">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">).</span><span class="n">toDF</span><span class="o">()</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="k">val</span> <span class="n">labelIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a GBT model.</span>
<span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="k">val</span> <span class="n">labelConverter</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="n">labels</span><span class="o">)</span>
<span class="c1">// Chain indexers and GBT in a Pipeline</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">accuracy</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">))</span>
<span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">2</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTClassificationModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned classification GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span></code></pre></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/classification/GBTClassifier.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassifier</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.GBTClassificationModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.*</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.rdd.RDD</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">RDD</span><span class="o">&lt;</span><span class="n">LabeledPoint</span><span class="o">&gt;</span> <span class="n">rdd</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="na">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">.</span><span class="na">sc</span><span class="o">(),</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">rdd</span><span class="o">,</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Index labels, adding metadata to the label column.</span>
<span class="c1">// Fit on whole dataset to include all labels in index.</span>
<span class="n">StringIndexerModel</span> <span class="n">labelIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">StringIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a GBT model.</span>
<span class="n">GBTClassifier</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">GBTClassifier</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span>
<span class="c1">// Convert indexed labels back to original labels.</span>
<span class="n">IndexToString</span> <span class="n">labelConverter</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">IndexToString</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setLabels</span><span class="o">(</span><span class="n">labelIndexer</span><span class="o">.</span><span class="na">labels</span><span class="o">());</span>
<span class="c1">// Chain indexers and GBT in a Pipeline</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">labelIndexer</span><span class="o">,</span> <span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">,</span> <span class="n">labelConverter</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexers.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;predictedLabel&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="n">MulticlassClassificationEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassClassificationEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;indexedLabel&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;precision&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Test Error = &quot;</span> <span class="o">+</span> <span class="o">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="o">));</span>
<span class="n">GBTClassificationModel</span> <span class="n">gbtModel</span> <span class="o">=</span>
<span class="o">(</span><span class="n">GBTClassificationModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">2</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned classification GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.classification.GBTClassifier">Python API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.classification</span> <span class="kn">import</span> <span class="n">GBTClassifier</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">StringIndexer</span><span class="p">,</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">MulticlassClassificationEvaluator</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.util</span> <span class="kn">import</span> <span class="n">MLUtils</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">toDF</span><span class="p">()</span>
<span class="c"># Index labels, adding metadata to the label column.</span>
<span class="c"># Fit on whole dataset to include all labels in index.</span>
<span class="n">labelIndexer</span> <span class="o">=</span> <span class="n">StringIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a GBT model.</span>
<span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTClassifier</span><span class="p">(</span><span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c"># Chain indexers and GBT in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">labelIndexer</span><span class="p">,</span> <span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexers.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">MulticlassClassificationEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;indexedLabel&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;precision&quot;</span><span class="p">)</span>
<span class="n">accuracy</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span> <span class="s">&quot;Test Error = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">accuracy</span><span class="p">)</span>
<span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>
<span class="k">print</span> <span class="n">gbtModel</span> <span class="c"># summary only</span></code></pre></div>
</div>
</div>
<h4 id="example-regression-1">Example: Regression</h4>
<p>Note: For this example dataset, <code>GBTRegressor</code> actually needs only 1 iteration, but that is
not true in general.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.regression.GBTRegressor">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressor</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressionModel</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span>
<span class="k">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="nc">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">).</span><span class="n">toDF</span><span class="o">()</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="k">val</span> <span class="n">featureIndexer</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="n">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">data</span><span class="o">)</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">trainingData</span><span class="o">,</span> <span class="n">testData</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// Train a GBT model.</span>
<span class="k">val</span> <span class="n">gbt</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">GBTRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">)</span>
<span class="c1">// Chain indexer and GBT in a Pipeline</span>
<span class="k">val</span> <span class="n">pipeline</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="n">setStages</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">))</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="k">val</span> <span class="n">model</span> <span class="k">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">)</span>
<span class="c1">// Make predictions.</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">)</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="n">show</span><span class="o">(</span><span class="mi">5</span><span class="o">)</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="k">val</span> <span class="n">evaluator</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="n">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="n">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">rmse</span> <span class="k">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">)</span>
<span class="k">val</span> <span class="n">gbtModel</span> <span class="k">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="o">(</span><span class="mi">1</span><span class="o">).</span><span class="n">asInstanceOf</span><span class="o">[</span><span class="kt">GBTRegressionModel</span><span class="o">]</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;Learned regression GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="n">toDebugString</span><span class="o">)</span></code></pre></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/regression/GBTRegressor.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">org.apache.spark.ml.Pipeline</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.PipelineStage</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.evaluation.RegressionEvaluator</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexer</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.feature.VectorIndexerModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressionModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.regression.GBTRegressor</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.rdd.RDD</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="c1">// Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">RDD</span><span class="o">&lt;</span><span class="n">LabeledPoint</span><span class="o">&gt;</span> <span class="n">rdd</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="na">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">.</span><span class="na">sc</span><span class="o">(),</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="o">);</span>
<span class="n">DataFrame</span> <span class="n">data</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">rdd</span><span class="o">,</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="c1">// Automatically identify categorical features, and index them.</span>
<span class="c1">// Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">VectorIndexerModel</span> <span class="n">featureIndexer</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">VectorIndexer</span><span class="o">()</span>
<span class="o">.</span><span class="na">setInputCol</span><span class="o">(</span><span class="s">&quot;features&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setOutputCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxCategories</span><span class="o">(</span><span class="mi">4</span><span class="o">)</span>
<span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">data</span><span class="o">);</span>
<span class="c1">// Split the data into training and test sets (30% held out for testing)</span>
<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">});</span>
<span class="n">DataFrame</span> <span class="n">trainingData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">DataFrame</span> <span class="n">testData</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// Train a GBT model.</span>
<span class="n">GBTRegressor</span> <span class="n">gbt</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">GBTRegressor</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setFeaturesCol</span><span class="o">(</span><span class="s">&quot;indexedFeatures&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMaxIter</span><span class="o">(</span><span class="mi">10</span><span class="o">);</span>
<span class="c1">// Chain indexer and GBT in a Pipeline</span>
<span class="n">Pipeline</span> <span class="n">pipeline</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">Pipeline</span><span class="o">()</span>
<span class="o">.</span><span class="na">setStages</span><span class="o">(</span><span class="k">new</span> <span class="n">PipelineStage</span><span class="o">[]</span> <span class="o">{</span><span class="n">featureIndexer</span><span class="o">,</span> <span class="n">gbt</span><span class="o">});</span>
<span class="c1">// Train model. This also runs the indexer.</span>
<span class="n">PipelineModel</span> <span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">trainingData</span><span class="o">);</span>
<span class="c1">// Make predictions.</span>
<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">testData</span><span class="o">);</span>
<span class="c1">// Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">,</span> <span class="s">&quot;features&quot;</span><span class="o">).</span><span class="na">show</span><span class="o">(</span><span class="mi">5</span><span class="o">);</span>
<span class="c1">// Select (prediction, true label) and compute test error</span>
<span class="n">RegressionEvaluator</span> <span class="n">evaluator</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">RegressionEvaluator</span><span class="o">()</span>
<span class="o">.</span><span class="na">setLabelCol</span><span class="o">(</span><span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setPredictionCol</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">)</span>
<span class="o">.</span><span class="na">setMetricName</span><span class="o">(</span><span class="s">&quot;rmse&quot;</span><span class="o">);</span>
<span class="kt">double</span> <span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="na">evaluate</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Root Mean Squared Error (RMSE) on test data = &quot;</span> <span class="o">+</span> <span class="n">rmse</span><span class="o">);</span>
<span class="n">GBTRegressionModel</span> <span class="n">gbtModel</span> <span class="o">=</span>
<span class="o">(</span><span class="n">GBTRegressionModel</span><span class="o">)(</span><span class="n">model</span><span class="o">.</span><span class="na">stages</span><span class="o">()[</span><span class="mi">1</span><span class="o">]);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Learned regression GBT model:\n&quot;</span> <span class="o">+</span> <span class="n">gbtModel</span><span class="o">.</span><span class="na">toDebugString</span><span class="o">());</span></code></pre></div>
</div>
<div data-lang="python">
<p>Refer to the <a href="api/python/pyspark.ml.html#pyspark.ml.regression.GBTRegressor">Python API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">pyspark.ml</span> <span class="kn">import</span> <span class="n">Pipeline</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.regression</span> <span class="kn">import</span> <span class="n">GBTRegressor</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.feature</span> <span class="kn">import</span> <span class="n">VectorIndexer</span>
<span class="kn">from</span> <span class="nn">pyspark.ml.evaluation</span> <span class="kn">import</span> <span class="n">RegressionEvaluator</span>
<span class="kn">from</span> <span class="nn">pyspark.mllib.util</span> <span class="kn">import</span> <span class="n">MLUtils</span>
<span class="c"># Load and parse the data file, converting it to a DataFrame.</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="p">(</span><span class="n">sc</span><span class="p">,</span> <span class="s">&quot;data/mllib/sample_libsvm_data.txt&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">toDF</span><span class="p">()</span>
<span class="c"># Automatically identify categorical features, and index them.</span>
<span class="c"># Set maxCategories so features with &gt; 4 distinct values are treated as continuous.</span>
<span class="n">featureIndexer</span> <span class="o">=</span>\
<span class="n">VectorIndexer</span><span class="p">(</span><span class="n">inputCol</span><span class="o">=</span><span class="s">&quot;features&quot;</span><span class="p">,</span> <span class="n">outputCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxCategories</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="c"># Split the data into training and test sets (30% held out for testing)</span>
<span class="p">(</span><span class="n">trainingData</span><span class="p">,</span> <span class="n">testData</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">randomSplit</span><span class="p">([</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">])</span>
<span class="c"># Train a GBT model.</span>
<span class="n">gbt</span> <span class="o">=</span> <span class="n">GBTRegressor</span><span class="p">(</span><span class="n">featuresCol</span><span class="o">=</span><span class="s">&quot;indexedFeatures&quot;</span><span class="p">,</span> <span class="n">maxIter</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="c"># Chain indexer and GBT in a Pipeline</span>
<span class="n">pipeline</span> <span class="o">=</span> <span class="n">Pipeline</span><span class="p">(</span><span class="n">stages</span><span class="o">=</span><span class="p">[</span><span class="n">featureIndexer</span><span class="p">,</span> <span class="n">gbt</span><span class="p">])</span>
<span class="c"># Train model. This also runs the indexer.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">pipeline</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">trainingData</span><span class="p">)</span>
<span class="c"># Make predictions.</span>
<span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">testData</span><span class="p">)</span>
<span class="c"># Select example rows to display.</span>
<span class="n">predictions</span><span class="o">.</span><span class="n">select</span><span class="p">(</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="s">&quot;features&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">show</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
<span class="c"># Select (prediction, true label) and compute test error</span>
<span class="n">evaluator</span> <span class="o">=</span> <span class="n">RegressionEvaluator</span><span class="p">(</span>
<span class="n">labelCol</span><span class="o">=</span><span class="s">&quot;label&quot;</span><span class="p">,</span> <span class="n">predictionCol</span><span class="o">=</span><span class="s">&quot;prediction&quot;</span><span class="p">,</span> <span class="n">metricName</span><span class="o">=</span><span class="s">&quot;rmse&quot;</span><span class="p">)</span>
<span class="n">rmse</span> <span class="o">=</span> <span class="n">evaluator</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">predictions</span><span class="p">)</span>
<span class="k">print</span> <span class="s">&quot;Root Mean Squared Error (RMSE) on test data = </span><span class="si">%g</span><span class="s">&quot;</span> <span class="o">%</span> <span class="n">rmse</span>
<span class="n">gbtModel</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stages</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="k">print</span> <span class="n">gbtModel</span> <span class="c"># summary only</span></code></pre></div>
</div>
</div>
<h2 id="one-vs-rest-aka-one-vs-all">One-vs-Rest (a.k.a. One-vs-All)</h2>
<p><a href="http://en.wikipedia.org/wiki/Multiclass_classification#One-vs.-rest">OneVsRest</a> is an example of a machine learning reduction for performing multiclass classification given a base classifier that can perform binary classification efficiently. It is also known as &#8220;One-vs-All.&#8221;</p>
<p><code>OneVsRest</code> is implemented as an <code>Estimator</code>. For the base classifier it takes instances of <code>Classifier</code> and creates a binary classification problem for each of the k classes. The classifier for class i is trained to predict whether the label is i or not, distinguishing class i from all other classes.</p>
<p>Predictions are made by evaluating each binary classifier; the index of the most confident classifier is output as the predicted label.</p>
<h3 id="example">Example</h3>
<p>The example below demonstrates how to load the
<a href="http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/iris.scale">Iris dataset</a>, parse it as a DataFrame and perform multiclass classification using <code>OneVsRest</code>. The confusion matrix and the per-label false positive rate on the test data are then computed to evaluate the model.</p>
<div class="codetabs">
<div data-lang="scala">
<p>Refer to the <a href="api/scala/index.html#org.apache.spark.ml.classification.OneVsRest">Scala API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-scala" data-lang="scala"><span class="k">import</span> <span class="nn">org.apache.spark.ml.classification.</span><span class="o">{</span><span class="nc">LogisticRegression</span><span class="o">,</span> <span class="nc">OneVsRest</span><span class="o">}</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MulticlassMetrics</span>
<span class="k">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span>
<span class="k">import</span> <span class="nn">org.apache.spark.sql.</span><span class="o">{</span><span class="nc">Row</span><span class="o">,</span> <span class="nc">SQLContext</span><span class="o">}</span>
<span class="k">val</span> <span class="n">sqlContext</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">SQLContext</span><span class="o">(</span><span class="n">sc</span><span class="o">)</span>
<span class="c1">// parse data into dataframe</span>
<span class="k">val</span> <span class="n">data</span> <span class="k">=</span> <span class="nc">MLUtils</span><span class="o">.</span><span class="n">loadLibSVMFile</span><span class="o">(</span><span class="n">sc</span><span class="o">,</span>
<span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="nc">Array</span><span class="o">(</span><span class="n">train</span><span class="o">,</span> <span class="n">test</span><span class="o">)</span> <span class="k">=</span> <span class="n">data</span><span class="o">.</span><span class="n">toDF</span><span class="o">().</span><span class="n">randomSplit</span><span class="o">(</span><span class="nc">Array</span><span class="o">(</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">))</span>
<span class="c1">// instantiate multiclass learner and train</span>
<span class="k">val</span> <span class="n">ovr</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">OneVsRest</span><span class="o">().</span><span class="n">setClassifier</span><span class="o">(</span><span class="k">new</span> <span class="nc">LogisticRegression</span><span class="o">)</span>
<span class="k">val</span> <span class="n">ovrModel</span> <span class="k">=</span> <span class="n">ovr</span><span class="o">.</span><span class="n">fit</span><span class="o">(</span><span class="n">train</span><span class="o">)</span>
<span class="c1">// score model on test data</span>
<span class="k">val</span> <span class="n">predictions</span> <span class="k">=</span> <span class="n">ovrModel</span><span class="o">.</span><span class="n">transform</span><span class="o">(</span><span class="n">test</span><span class="o">).</span><span class="n">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">)</span>
<span class="k">val</span> <span class="n">predictionsAndLabels</span> <span class="k">=</span> <span class="n">predictions</span><span class="o">.</span><span class="n">map</span> <span class="o">{</span><span class="k">case</span> <span class="nc">Row</span><span class="o">(</span><span class="n">p</span><span class="k">:</span> <span class="kt">Double</span><span class="o">,</span> <span class="n">l</span><span class="k">:</span> <span class="kt">Double</span><span class="o">)</span> <span class="k">=&gt;</span> <span class="o">(</span><span class="n">p</span><span class="o">,</span> <span class="n">l</span><span class="o">)}</span>
<span class="c1">// compute confusion matrix</span>
<span class="k">val</span> <span class="n">metrics</span> <span class="k">=</span> <span class="k">new</span> <span class="nc">MulticlassMetrics</span><span class="o">(</span><span class="n">predictionsAndLabels</span><span class="o">)</span>
<span class="n">println</span><span class="o">(</span><span class="n">metrics</span><span class="o">.</span><span class="n">confusionMatrix</span><span class="o">)</span>
<span class="c1">// the Iris DataSet has three classes</span>
<span class="k">val</span> <span class="n">numClasses</span> <span class="k">=</span> <span class="mi">3</span>
<span class="n">println</span><span class="o">(</span><span class="s">&quot;label\tfpr\n&quot;</span><span class="o">)</span>
<span class="o">(</span><span class="mi">0</span> <span class="n">until</span> <span class="n">numClasses</span><span class="o">).</span><span class="n">foreach</span> <span class="o">{</span> <span class="n">index</span> <span class="k">=&gt;</span>
<span class="k">val</span> <span class="n">label</span> <span class="k">=</span> <span class="n">index</span><span class="o">.</span><span class="n">toDouble</span>
<span class="n">println</span><span class="o">(</span><span class="n">label</span> <span class="o">+</span> <span class="s">&quot;\t&quot;</span> <span class="o">+</span> <span class="n">metrics</span><span class="o">.</span><span class="n">falsePositiveRate</span><span class="o">(</span><span class="n">label</span><span class="o">))</span>
<span class="o">}</span></code></pre></div>
</div>
<div data-lang="java">
<p>Refer to the <a href="api/java/org/apache/spark/ml/classification/OneVsRest.html">Java API docs</a> for more details.</p>
<div class="highlight"><pre><code class="language-java" data-lang="java"><span class="kn">import</span> <span class="nn">org.apache.spark.SparkConf</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.api.java.JavaSparkContext</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.LogisticRegression</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRest</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.ml.classification.OneVsRestModel</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.evaluation.MulticlassMetrics</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.linalg.Matrix</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.regression.LabeledPoint</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.mllib.util.MLUtils</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.rdd.RDD</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.DataFrame</span><span class="o">;</span>
<span class="kn">import</span> <span class="nn">org.apache.spark.sql.SQLContext</span><span class="o">;</span>
<span class="n">SparkConf</span> <span class="n">conf</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SparkConf</span><span class="o">().</span><span class="na">setAppName</span><span class="o">(</span><span class="s">&quot;JavaOneVsRestExample&quot;</span><span class="o">);</span>
<span class="n">JavaSparkContext</span> <span class="n">jsc</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">JavaSparkContext</span><span class="o">(</span><span class="n">conf</span><span class="o">);</span>
<span class="n">SQLContext</span> <span class="n">jsql</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">SQLContext</span><span class="o">(</span><span class="n">jsc</span><span class="o">);</span>
<span class="n">RDD</span><span class="o">&lt;</span><span class="n">LabeledPoint</span><span class="o">&gt;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">MLUtils</span><span class="o">.</span><span class="na">loadLibSVMFile</span><span class="o">(</span><span class="n">jsc</span><span class="o">.</span><span class="na">sc</span><span class="o">(),</span>
<span class="s">&quot;data/mllib/sample_multiclass_classification_data.txt&quot;</span><span class="o">);</span>
<span class="n">DataFrame</span> <span class="n">dataFrame</span> <span class="o">=</span> <span class="n">jsql</span><span class="o">.</span><span class="na">createDataFrame</span><span class="o">(</span><span class="n">data</span><span class="o">,</span> <span class="n">LabeledPoint</span><span class="o">.</span><span class="na">class</span><span class="o">);</span>
<span class="n">DataFrame</span><span class="o">[]</span> <span class="n">splits</span> <span class="o">=</span> <span class="n">dataFrame</span><span class="o">.</span><span class="na">randomSplit</span><span class="o">(</span><span class="k">new</span> <span class="kt">double</span><span class="o">[]</span> <span class="o">{</span><span class="mf">0.7</span><span class="o">,</span> <span class="mf">0.3</span><span class="o">},</span> <span class="mi">12345</span><span class="o">);</span>
<span class="n">DataFrame</span> <span class="n">train</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">0</span><span class="o">];</span>
<span class="n">DataFrame</span> <span class="n">test</span> <span class="o">=</span> <span class="n">splits</span><span class="o">[</span><span class="mi">1</span><span class="o">];</span>
<span class="c1">// instantiate the One Vs Rest Classifier</span>
<span class="n">OneVsRest</span> <span class="n">ovr</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">OneVsRest</span><span class="o">().</span><span class="na">setClassifier</span><span class="o">(</span><span class="k">new</span> <span class="nf">LogisticRegression</span><span class="o">());</span>
<span class="c1">// train the multiclass model</span>
<span class="n">OneVsRestModel</span> <span class="n">ovrModel</span> <span class="o">=</span> <span class="n">ovr</span><span class="o">.</span><span class="na">fit</span><span class="o">(</span><span class="n">train</span><span class="o">.</span><span class="na">cache</span><span class="o">());</span>
<span class="c1">// score the model on test data</span>
<span class="n">DataFrame</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">ovrModel</span>
<span class="o">.</span><span class="na">transform</span><span class="o">(</span><span class="n">test</span><span class="o">)</span>
<span class="o">.</span><span class="na">select</span><span class="o">(</span><span class="s">&quot;prediction&quot;</span><span class="o">,</span> <span class="s">&quot;label&quot;</span><span class="o">);</span>
<span class="c1">// obtain metrics</span>
<span class="n">MulticlassMetrics</span> <span class="n">metrics</span> <span class="o">=</span> <span class="k">new</span> <span class="nf">MulticlassMetrics</span><span class="o">(</span><span class="n">predictions</span><span class="o">);</span>
<span class="n">Matrix</span> <span class="n">confusionMatrix</span> <span class="o">=</span> <span class="n">metrics</span><span class="o">.</span><span class="na">confusionMatrix</span><span class="o">();</span>
<span class="c1">// output the Confusion Matrix</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;Confusion Matrix&quot;</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="n">confusionMatrix</span><span class="o">);</span>
<span class="c1">// compute the false positive rate per label</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">();</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">(</span><span class="s">&quot;label\tfpr\n&quot;</span><span class="o">);</span>
<span class="c1">// the Iris DataSet has three classes</span>
<span class="kt">int</span> <span class="n">numClasses</span> <span class="o">=</span> <span class="mi">3</span><span class="o">;</span>
<span class="k">for</span> <span class="o">(</span><span class="kt">int</span> <span class="n">index</span> <span class="o">=</span> <span class="mi">0</span><span class="o">;</span> <span class="n">index</span> <span class="o">&lt;</span> <span class="n">numClasses</span><span class="o">;</span> <span class="n">index</span><span class="o">++)</span> <span class="o">{</span>
<span class="kt">double</span> <span class="n">label</span> <span class="o">=</span> <span class="o">(</span><span class="kt">double</span><span class="o">)</span> <span class="n">index</span><span class="o">;</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">print</span><span class="o">(</span><span class="n">label</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">print</span><span class="o">(</span><span class="s">&quot;\t&quot;</span><span class="o">);</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">print</span><span class="o">(</span><span class="n">metrics</span><span class="o">.</span><span class="na">falsePositiveRate</span><span class="o">(</span><span class="n">label</span><span class="o">));</span>
<span class="n">System</span><span class="o">.</span><span class="na">out</span><span class="o">.</span><span class="na">println</span><span class="o">();</span>
<span class="o">}</span></code></pre></div>
</div>
</div>
</div> <!-- /container -->
<script src="js/vendor/jquery-1.8.0.min.js"></script>
<script src="js/vendor/bootstrap.min.js"></script>
<script src="js/vendor/anchor.min.js"></script>
<script src="js/main.js"></script>
<!-- MathJax Section -->
<script type="text/x-mathjax-config">
// MathJax configuration block: enable automatic equation numbering
// using MathJax's "AMS" numbering mode.
MathJax.Hub.Config({
  TeX: {
    equationNumbers: {
      autoNumber: "AMS"
    }
  }
});
</script>
<script>
// Load MathJax asynchronously by injecting a <script> element into <head>.
// The protocol is picked explicitly instead of using a protocol-relative
// "//..." URL, because a protocol-relative URL would not resolve when the
// docs are opened from the local filesystem (file://).
// The old cdn.mathjax.org host (and its floating "latest" alias) has been
// retired by the MathJax project, so a pinned MathJax 2.x release is loaded
// from cdnjs instead.
(function(d, script) {
script = d.createElement('script');
script.type = 'text/javascript';
script.async = true;
// Set the TeX delimiters and skipped tags once the MathJax script has loaded.
script.onload = function(){
MathJax.Hub.Config({
tex2jax: {
inlineMath: [ ["$", "$"], ["\\\\(","\\\\)"] ],
displayMath: [ ["$$","$$"], ["\\[", "\\]"] ],
processEscapes: true,
// Do not look for math inside these elements (e.g. code listings in <pre>).
skipTags: ['script', 'noscript', 'style', 'textarea', 'pre']
}
});
};
script.src = ('https:' == document.location.protocol ? 'https://' : 'http://') +
'cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML';
d.getElementsByTagName('head')[0].appendChild(script);
}(document));
</script>
</body>
</html>